mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 06:58:53 +00:00
Merge branch 'main' into playground-ui
This commit is contained in:
commit
5119671f02
52 changed files with 938 additions and 773 deletions
|
@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
|
|||
vector_db_id: str
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
provider_id: str | None = None
|
||||
provider_vector_db_id: str | None = None
|
||||
|
||||
|
||||
|
|
|
@ -338,7 +338,7 @@ class VectorIO(Protocol):
|
|||
@webmethod(route="/openai/v1/vector_stores", method="POST")
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
name: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
|
|
|
@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if config.external_providers_dir and not config.external_providers_dir.exists():
|
||||
config.external_providers_dir.mkdir(exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||
run_args = formulate_run_args(args.image_type, args.image_name)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
|
|
|
@ -82,39 +82,6 @@ class StackRun(Subcommand):
|
|||
return ImageType.CONDA.value, args.image_name
|
||||
return args.image_type, args.image_name
|
||||
|
||||
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and template name from args.config"""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, template_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
|
@ -125,8 +92,15 @@ class StackRun(Subcommand):
|
|||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
|
||||
# Resolve config file and template name first
|
||||
config_file, template_name = self._resolve_config_and_template(args)
|
||||
if args.config:
|
||||
try:
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
|
||||
config_file = resolve_config_or_template(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
|
||||
|
@ -164,18 +138,14 @@ class StackRun(Subcommand):
|
|||
if callable(getattr(args, arg)):
|
||||
continue
|
||||
if arg == "config":
|
||||
if template_name:
|
||||
server_args.template = str(template_name)
|
||||
else:
|
||||
# Set the config file path
|
||||
server_args.config = str(config_file)
|
||||
server_args.config = str(config_file)
|
||||
else:
|
||||
setattr(server_args, arg, getattr(args, arg))
|
||||
|
||||
# Run the server
|
||||
server_main(server_args)
|
||||
else:
|
||||
run_args = formulate_run_args(image_type, image_name, config, template_name)
|
||||
run_args = formulate_run_args(image_type, image_name)
|
||||
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
|
|
31
llama_stack/cli/utils.py
Normal file
31
llama_stack/cli/utils.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def add_config_template_args(parser: argparse.ArgumentParser):
|
||||
"""Add unified config/template arguments with backward compatibility."""
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
group.add_argument(
|
||||
"config",
|
||||
nargs="?",
|
||||
help="Configuration file path or template name",
|
||||
)
|
||||
|
||||
# Backward compatibility arguments (deprecated)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
dest="config",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--template",
|
||||
dest="config",
|
||||
help="(DEPRECATED) Use positional argument [config] instead. Template name",
|
||||
)
|
|
@ -214,9 +214,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
|
@ -226,9 +224,7 @@ class VectorIORouter(VectorIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
return await self.routing_table.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
|
@ -240,12 +236,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
# drop from registry
|
||||
await self.routing_table.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
@ -258,9 +249,7 @@ class VectorIORouter(VectorIO):
|
|||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
return await self.routing_table.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
|
@ -278,9 +267,7 @@ class VectorIORouter(VectorIO):
|
|||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
@ -297,9 +284,7 @@ class VectorIORouter(VectorIO):
|
|||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
return await self.routing_table.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
|
@ -314,9 +299,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
@ -327,9 +310,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
@ -341,9 +322,7 @@ class VectorIORouter(VectorIO):
|
|||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
return await self.routing_table.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
|
@ -355,9 +334,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
# Route based on vector store ID
|
||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
return await self.routing_table.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessRule,
|
||||
RoutableObject,
|
||||
|
@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def assert_action_allowed(
|
||||
self,
|
||||
action: Action,
|
||||
type: str,
|
||||
identifier: str,
|
||||
) -> None:
|
||||
"""Fetch a registered object by type/identifier and enforce the given action permission."""
|
||||
obj = await self.get_object_by_identifier(type, identifier)
|
||||
if obj is None:
|
||||
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
|
||||
user = get_authenticated_user()
|
||||
if not is_action_allowed(self.policy, action, obj, user):
|
||||
raise AccessDeniedError(action, obj, user)
|
||||
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
|
|
@ -4,11 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import (
|
||||
VectorDBWithOwner,
|
||||
)
|
||||
|
@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if existing_vector_db is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
await self.unregister_object(existing_vector_db)
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
return result
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
|
|
@ -32,6 +32,7 @@ from openai import BadRequestError
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_template_args
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
@ -53,6 +54,7 @@ from llama_stack.distribution.stack import (
|
|||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
@ -377,20 +379,8 @@ class ClientVersionMiddleware:
|
|||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
"--yaml-config",
|
||||
dest="config",
|
||||
help="(Deprecated) Path to YAML configuration file - use --config instead",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
dest="config",
|
||||
help="Path to YAML configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template",
|
||||
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
|
||||
)
|
||||
|
||||
add_config_template_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
|
@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
log_line = ""
|
||||
if hasattr(args, "config") and args.config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif hasattr(args, "template") and args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --config or --template must be provided")
|
||||
config_file = resolve_config_or_template(args.config, Mode.RUN)
|
||||
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
|
@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
|
@ -592,12 +566,29 @@ def main(args: argparse.Namespace | None = None):
|
|||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
# stack trace when using Ctrl+C or kill -2 (SIGINT).
|
||||
# SIGTERM (kill -15) works fine without this because Python doesn't
|
||||
# have a default handler for it.
|
||||
#
|
||||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
|
@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
|||
set -x
|
||||
|
||||
if [ -n "$yaml_config" ]; then
|
||||
yaml_config_arg="--config $yaml_config"
|
||||
yaml_config_arg="$yaml_config"
|
||||
else
|
||||
yaml_config_arg=""
|
||||
fi
|
||||
|
|
125
llama_stack/distribution/utils/config_resolution.py
Normal file
125
llama_stack/distribution/utils/config_resolution.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="config_resolution")
|
||||
|
||||
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
|
||||
|
||||
|
||||
class Mode(StrEnum):
|
||||
RUN = "run"
|
||||
BUILD = "build"
|
||||
|
||||
|
||||
def resolve_config_or_template(
|
||||
config_or_template: str,
|
||||
mode: Mode = Mode.RUN,
|
||||
) -> Path:
|
||||
"""
|
||||
Resolve a config/template argument to a concrete config file path.
|
||||
|
||||
Args:
|
||||
config_or_template: User input (file path, template name, or built distribution)
|
||||
mode: Mode resolving for ("run", "build", "server")
|
||||
|
||||
Returns:
|
||||
Path to the resolved config file
|
||||
|
||||
Raises:
|
||||
ValueError: If resolution fails
|
||||
"""
|
||||
|
||||
# Strategy 1: Try as file path first
|
||||
config_path = Path(config_or_template)
|
||||
if config_path.exists() and config_path.is_file():
|
||||
logger.info(f"Using file path: {config_path}")
|
||||
return config_path.resolve()
|
||||
|
||||
# Strategy 2: Try as template name (if no .yaml extension)
|
||||
if not config_or_template.endswith(".yaml"):
|
||||
template_config = _get_template_config_path(config_or_template, mode)
|
||||
if template_config.exists():
|
||||
logger.info(f"Using template: {template_config}")
|
||||
return template_config
|
||||
|
||||
# Strategy 3: Try as built distribution name
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
# Strategy 4: Failed - provide helpful error
|
||||
raise ValueError(_format_resolution_error(config_or_template, mode))
|
||||
|
||||
|
||||
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
|
||||
"""Get the config file path for a template."""
|
||||
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
|
||||
|
||||
|
||||
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
|
||||
"""Format a helpful error message for resolution failures."""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
template_path = _get_template_config_path(config_or_template, mode)
|
||||
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
|
||||
|
||||
available_templates = _get_available_templates()
|
||||
templates_str = ", ".join(available_templates) if available_templates else "none found"
|
||||
|
||||
return f"""Could not resolve config or template '{config_or_template}'.
|
||||
|
||||
Tried the following locations:
|
||||
1. As file path: {Path(config_or_template).resolve()}
|
||||
2. As template: {template_path}
|
||||
3. As built distribution: ({distrib_path}, {distrib_path2})
|
||||
|
||||
Available templates: {templates_str}
|
||||
|
||||
Did you mean one of these templates?
|
||||
{_format_template_suggestions(available_templates, config_or_template)}
|
||||
"""
|
||||
|
||||
|
||||
def _get_available_templates() -> list[str]:
|
||||
"""Get list of available template names."""
|
||||
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
|
||||
return []
|
||||
|
||||
return list(
|
||||
set(
|
||||
[d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
|
||||
+ [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
|
||||
"""Format template suggestions for error messages, showing closest matches first."""
|
||||
if not templates:
|
||||
return " (no templates found)"
|
||||
|
||||
import difflib
|
||||
|
||||
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
|
||||
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
|
||||
display_templates = close_matches if close_matches else templates[:3]
|
||||
|
||||
suggestions = [f" - {t}" for t in display_templates]
|
||||
return "\n".join(suggestions)
|
|
@ -21,7 +21,7 @@ from pathlib import Path
|
|||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
|
||||
env_name = ""
|
||||
|
||||
if image_type == LlamaStackImageType.CONDA.value:
|
||||
|
|
|
@ -10,6 +10,7 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
|
|||
|
||||
|
||||
async def get_raw_document_text(document: Document) -> str:
|
||||
if not document.mime_type.startswith("text/"):
|
||||
# Handle deprecated text/yaml mime type with warning
|
||||
if document.mime_type == "text/yaml":
|
||||
warnings.warn(
|
||||
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
|
|
|
@ -128,6 +128,11 @@ class AgentPersistence:
|
|||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
|
|
|
@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
|
||||
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.together_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
|
||||
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.groq_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
|
||||
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
|
||||
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
|
|||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
||||
1. Install the necessary dependencies.
|
||||
2. Configure your Llama Stack project to use Faiss.
|
||||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## Installation
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
adapter = CerebrasCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..cerebras.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: CerebrasCompatConfig
|
||||
|
||||
def __init__(self, config: CerebrasCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="cerebras_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
cerebras_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Cerebras models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Cerebras API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.cerebras.ai/v1",
|
||||
description="The URL for the Cerebras API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.cerebras.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
adapter = FireworksCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Fireworks models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Fireworks API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..fireworks.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: FireworksCompatConfig
|
||||
|
||||
def __init__(self, config: FireworksCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="fireworks_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
adapter = GroqCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.groq.com/openai/v1",
|
||||
description="The URL for the Groq API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.groq.com/openai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..groq.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqCompatConfig
|
||||
|
||||
def __init__(self, config: GroqCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import SambaNovaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .sambanova import SambaNovaCompatInferenceAdapter
|
||||
|
||||
adapter = SambaNovaCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for SambaNova models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The SambaNova API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..sambanova.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: SambaNovaCompatConfig
|
||||
|
||||
def __init__(self, config: SambaNovaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import TogetherCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .together import TogetherCompatInferenceAdapter
|
||||
|
||||
adapter = TogetherCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Together models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherCompatConfig(BaseModel):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The Together API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.together.xyz/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..together.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: TogetherCompatConfig
|
||||
|
||||
def __init__(self, config: TogetherCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="together_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -161,7 +161,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
name: str | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
|
|
|
@ -83,6 +83,7 @@ class SQLiteTraceStore(TraceStore):
|
|||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
WHERE root_span_id IS NOT NULL
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
|
@ -166,7 +167,11 @@ class SQLiteTraceStore(TraceStore):
|
|||
return spans_by_id
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
query = "SELECT * FROM traces WHERE trace_id = ?"
|
||||
query = """
|
||||
SELECT *
|
||||
FROM traces t
|
||||
WHERE t.trace_id = ?
|
||||
"""
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (trace_id,)) as cursor:
|
||||
|
|
|
@ -19,12 +19,7 @@ distribution_spec:
|
|||
- remote::anthropic
|
||||
- remote::gemini
|
||||
- remote::groq
|
||||
- remote::fireworks-openai-compat
|
||||
- remote::llama-openai-compat
|
||||
- remote::together-openai-compat
|
||||
- remote::groq-openai-compat
|
||||
- remote::sambanova-openai-compat
|
||||
- remote::cerebras-openai-compat
|
||||
- remote::sambanova
|
||||
- remote::passthrough
|
||||
- inline::sentence-transformers
|
||||
|
|
|
@ -90,36 +90,11 @@ providers:
|
|||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
- provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::fireworks-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY}
|
||||
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::llama-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.llama.com/compat/v1/
|
||||
api_key: ${env.LLAMA_API_KEY}
|
||||
- provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::together-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
- provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::groq-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::sambanova-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
- provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::cerebras-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY}
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
|
|
|
@ -19,12 +19,7 @@ distribution_spec:
|
|||
- remote::anthropic
|
||||
- remote::gemini
|
||||
- remote::groq
|
||||
- remote::fireworks-openai-compat
|
||||
- remote::llama-openai-compat
|
||||
- remote::together-openai-compat
|
||||
- remote::groq-openai-compat
|
||||
- remote::sambanova-openai-compat
|
||||
- remote::cerebras-openai-compat
|
||||
- remote::sambanova
|
||||
- remote::passthrough
|
||||
- inline::sentence-transformers
|
||||
|
|
|
@ -90,36 +90,11 @@ providers:
|
|||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
- provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::fireworks-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY}
|
||||
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::llama-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.llama.com/compat/v1/
|
||||
api_key: ${env.LLAMA_API_KEY}
|
||||
- provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::together-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
- provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::groq-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::sambanova-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
- provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__}
|
||||
provider_type: remote::cerebras-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY}
|
||||
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue