mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 04:32:25 +00:00
Merge branch 'main' into make-provider-model-id-non-optional
This commit is contained in:
commit
bd9bad5f84
176 changed files with 4121 additions and 2735 deletions
|
|
@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum):
|
|||
custom = "custom"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGSearchMode(Enum):
|
||||
"""
|
||||
Search modes for RAG query retrieval:
|
||||
- VECTOR: Uses vector similarity search for semantic matching
|
||||
- KEYWORD: Uses keyword-based search for exact matching
|
||||
- HYBRID: Combines both vector and keyword search for better results
|
||||
"""
|
||||
|
||||
VECTOR = "vector"
|
||||
KEYWORD = "keyword"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DefaultRAGQueryGeneratorConfig(BaseModel):
|
||||
type: Literal["default"] = "default"
|
||||
|
|
@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel):
|
|||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
mode: str | None = None
|
||||
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
|
||||
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||
|
||||
@field_validator("chunk_template")
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class VectorDB(Resource):
|
|||
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
vector_db_name: str | None = None
|
||||
|
||||
@property
|
||||
def vector_db_id(self) -> str:
|
||||
|
|
@ -70,6 +71,7 @@ class VectorDBs(Protocol):
|
|||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
"""Register a vector database.
|
||||
|
|
@ -78,6 +80,7 @@ class VectorDBs(Protocol):
|
|||
:param embedding_model: The embedding model to use.
|
||||
:param embedding_dimension: The dimension of the embedding model.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param vector_db_name: The name of the vector database.
|
||||
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -346,7 +346,6 @@ class VectorIO(Protocol):
|
|||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
"""Creates a vector store.
|
||||
|
||||
|
|
@ -358,7 +357,6 @@ class VectorIO(Protocol):
|
|||
:param embedding_model: The embedding model to use for this vector store.
|
||||
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
|
||||
:param provider_id: The ID of the provider to use for this vector store.
|
||||
:param provider_vector_db_id: The provider-specific vector database ID.
|
||||
:returns: A VectorStoreObject representing the created vector store.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
providers = dict()
|
||||
providers_list: dict[str, str | list[str]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
|
|
@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
if provider in providers_for_api:
|
||||
providers.setdefault(api, []).append(provider)
|
||||
if api not in providers_list:
|
||||
providers_list[api] = []
|
||||
# Use type guarding to ensure we have a list
|
||||
provider_value = providers_list[api]
|
||||
if isinstance(provider_value, list):
|
||||
provider_value.append(provider)
|
||||
else:
|
||||
# Convert string to list and append
|
||||
providers_list[api] = [provider_value, provider]
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
|
|
@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=providers,
|
||||
providers=providers_list,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
if not args.image_type:
|
||||
|
|
@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
|
||||
|
||||
providers = dict()
|
||||
providers: dict[str, str | list[str]] = dict()
|
||||
for api, providers_for_api in get_provider_registry().items():
|
||||
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
|
||||
if not available_providers:
|
||||
|
|
@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config(
|
|||
if not image_name:
|
||||
raise ValueError("Please specify an image name when building a venv image")
|
||||
|
||||
# At this point, image_name should be guaranteed to be a string
|
||||
if image_name is None:
|
||||
raise ValueError("image_name should not be None after validation")
|
||||
|
||||
if template_name:
|
||||
build_dir = DISTRIBS_BASE_DIR / template_name
|
||||
build_file_path = build_dir / f"{template_name}-build.yaml"
|
||||
else:
|
||||
if image_name is None:
|
||||
raise ValueError("image_name cannot be None")
|
||||
build_dir = DISTRIBS_BASE_DIR / image_name
|
||||
build_file_path = build_dir / f"{image_name}-build.yaml"
|
||||
|
||||
|
|
@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config(
|
|||
build_file_path,
|
||||
image_name,
|
||||
template_or_config=template_name or config_path or str(build_file_path),
|
||||
run_config=run_config_file,
|
||||
run_config=run_config_file.as_posix() if run_config_file else None,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"Failed to build image {image_name}")
|
||||
|
|
|
|||
|
|
@ -83,46 +83,57 @@ class StackRun(Subcommand):
|
|||
return ImageType.CONDA.value, args.image_name
|
||||
return args.image_type, args.image_name
|
||||
|
||||
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
|
||||
"""Resolve config file path and template name from args.config"""
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
if not args.config:
|
||||
return None, None
|
||||
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
return config_file, template_name
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
|
||||
if args.enable_ui:
|
||||
self._start_ui_development_server(args.port)
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
|
||||
# Resolve config file and template name first
|
||||
config_file, template_name = self._resolve_config_and_template(args)
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
|
||||
self.parser.error("Config file is required for venv and conda environments")
|
||||
|
||||
if args.config:
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if this is a template
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
|
||||
if config_file.exists():
|
||||
template_name = args.config
|
||||
|
||||
if not config_file.exists() and not has_yaml_suffix:
|
||||
# check if it's a build config saved to ~/.llama dir
|
||||
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
|
||||
|
||||
if not config_file.exists():
|
||||
self.parser.error(
|
||||
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
|
||||
)
|
||||
|
||||
if not config_file.is_file():
|
||||
self.parser.error(
|
||||
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
|
||||
)
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
|
||||
try:
|
||||
|
|
@ -138,8 +149,6 @@ class StackRun(Subcommand):
|
|||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
else:
|
||||
config = None
|
||||
config_file = None
|
||||
template_name = None
|
||||
|
||||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
|
|
@ -172,10 +181,7 @@ class StackRun(Subcommand):
|
|||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
if template_name:
|
||||
run_args.extend(["--template", str(template_name)])
|
||||
else:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
for env_var in args.env:
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def is_action_allowed(
|
|||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||
qualified_resource_id = f"{resource.type}::{resource.identifier}"
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ FROM $container_base
|
|||
WORKDIR /app
|
||||
|
||||
# We install the Python 3.12 dev headers and build tools so that any
|
||||
# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.
|
||||
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
|
||||
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.12 python3.12-pip python3.12-wheel \
|
||||
|
|
@ -169,7 +169,7 @@ if [ -n "$run_config" ]; then
|
|||
echo "Copying external providers directory: $external_providers_dir"
|
||||
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
|
||||
add_to_container << EOF
|
||||
COPY --chmod=g+w providers.d /.llama/providers.d
|
||||
COPY providers.d /.llama/providers.d
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.distribution.distribution import (
|
|||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
|
|
@ -164,7 +164,8 @@ def upgrade_from_routing_table(
|
|||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**replace_env_vars(config_dict))
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
|
|
@ -175,4 +176,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
return StackRunConfig(**replace_env_vars(config_dict))
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -81,6 +82,7 @@ class VectorIORouter(VectorIO):
|
|||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
|
|
@ -89,6 +91,7 @@ class VectorIORouter(VectorIO):
|
|||
embedding_model,
|
||||
embedding_dimension,
|
||||
provider_id,
|
||||
vector_db_name,
|
||||
provider_vector_db_id,
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +126,6 @@ class VectorIORouter(VectorIO):
|
|||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = None,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
|
||||
|
|
@ -135,17 +137,17 @@ class VectorIORouter(VectorIO):
|
|||
embedding_model, embedding_dimension = embedding_model_info
|
||||
logger.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
|
||||
vector_db_id = name
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
provider_id,
|
||||
provider_vector_db_id,
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
)
|
||||
|
||||
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
||||
vector_db_id,
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
chunking_strategy=chunking_strategy,
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
|
|
@ -62,6 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"provider_resource_id": provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_db_name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from llama_stack.distribution.server.routes import (
|
|||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
|
|
@ -439,13 +440,13 @@ def main(args: argparse.Namespace | None = None):
|
|||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**config)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
|
|
@ -112,6 +113,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
||||
continue
|
||||
|
||||
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
|
||||
continue
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||
|
|
@ -261,6 +267,13 @@ def _convert_string_to_proper_type(value: str) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
|
||||
if "image_name" in config_dict and config_dict["image_name"] is not None:
|
||||
config_dict["image_name"] = str(config_dict["image_name"])
|
||||
return config_dict
|
||||
|
||||
|
||||
def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
||||
"""Validate and split an environment variable key-value pair."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -6,12 +6,9 @@
|
|||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_contexts_async_generator(
|
||||
def preserve_contexts_async_generator[T](
|
||||
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ class LocalfsFilesImpl(Files):
|
|||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
|
|
|
|||
|
|
@ -123,7 +123,8 @@ class TorchtunePostTrainingImpl:
|
|||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
|
|
|
|||
|
|
@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
|
||||
raise ValueError(
|
||||
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
|
||||
)
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
|
||||
# Use the inference API's model resolution instead of hardcoded mappings
|
||||
# This allows the shield to work with any registered model
|
||||
model_id = shield.provider_resource_id
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model_id in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model,
|
||||
model=model_id,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run(messages)
|
||||
|
|
@ -183,20 +196,21 @@ class LlamaGuardShield:
|
|||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: list[str] | None = None,
|
||||
safety_categories: list[str] | None = None,
|
||||
):
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
if safety_categories is None:
|
||||
safety_categories = []
|
||||
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
||||
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
self.model = model
|
||||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
self.safety_categories = safety_categories
|
||||
|
||||
def check_unsafe_response(self, response: str) -> str | None:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
|
|
@ -214,7 +228,7 @@ class LlamaGuardShield:
|
|||
|
||||
final_categories = []
|
||||
|
||||
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
|
||||
all_categories = self.safety_categories
|
||||
for cat in all_categories:
|
||||
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
|
||||
if cat_code in excluded_categories:
|
||||
|
|
|
|||
|
|
@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
# Load existing OpenAI vector stores using the mixin method
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# Cleanup if needed
|
||||
|
|
@ -261,42 +261,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
# OpenAI Vector Store Mixin abstract method implementations
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Save vector store metadata to kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
start_key = OPENAI_VECTOR_STORES_PREFIX
|
||||
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
|
||||
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
stores = {}
|
||||
for store_data in stored_openai_stores:
|
||||
store_info = json.loads(store_data)
|
||||
stores[store_info["id"]] = store_info
|
||||
return stores
|
||||
|
||||
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Update vector store metadata in kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
|
||||
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||
"""Delete vector store metadata from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to kvstore."""
|
||||
"""Save vector store file data to kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
|
|
@ -324,7 +292,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from kvstore."""
|
||||
"""Delete vector store data from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
keys_to_delete = [
|
||||
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
|
||||
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await self.kvstore.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete key {key}: {e}")
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -6,14 +6,24 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class SQLiteVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
db_path: str = Field(description="Path to the SQLite database file")
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="sqlite_vec_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
from typing import Any
|
||||
|
|
@ -24,6 +25,8 @@ from llama_stack.apis.vector_io import (
|
|||
VectorIO,
|
||||
)
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
|
|
@ -40,6 +43,13 @@ KEYWORD_SEARCH = "keyword"
|
|||
HYBRID_SEARCH = "hybrid"
|
||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::"
|
||||
|
||||
|
||||
def serialize_vector(vector: list[float]) -> bytes:
|
||||
"""Serialize a list of floats into a compact binary representation."""
|
||||
|
|
@ -108,6 +118,10 @@ def _rrf_rerank(
|
|||
return rrf_scores
|
||||
|
||||
|
||||
def _make_sql_identifier(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
||||
class SQLiteVecIndex(EmbeddingIndex):
|
||||
"""
|
||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||
|
|
@ -117,13 +131,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int, db_path: str, bank_id: str):
|
||||
def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
|
||||
self.dimension = dimension
|
||||
self.db_path = db_path
|
||||
self.bank_id = bank_id
|
||||
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
|
||||
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
|
||||
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
|
||||
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
|
||||
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
|
||||
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
|
||||
self.kvstore = kvstore
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, db_path: str, bank_id: str):
|
||||
|
|
@ -138,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Create the table to store chunk metadata.
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
|
||||
CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
|
||||
id TEXT PRIMARY KEY,
|
||||
chunk TEXT
|
||||
);
|
||||
""")
|
||||
# Create the virtual table for embeddings.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
|
||||
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
|
||||
""")
|
||||
connection.commit()
|
||||
|
|
@ -153,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
# based on query. Implementation of the change on client side will allow passing the search_mode option
|
||||
# during initialization to make it easier to create the table that is required.
|
||||
cur.execute(f"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
|
||||
USING fts5(id, content);
|
||||
""")
|
||||
connection.commit()
|
||||
|
|
@ -168,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
|
||||
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
|
|
@ -202,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
|
||||
cur.executemany(
|
||||
f"""
|
||||
INSERT INTO {self.metadata_table} (id, chunk)
|
||||
INSERT INTO [{self.metadata_table}] (id, chunk)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
|
||||
""",
|
||||
|
|
@ -220,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
]
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
|
||||
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
|
||||
embedding_data,
|
||||
)
|
||||
|
||||
|
|
@ -228,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
|
||||
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
|
||||
cur.executemany(
|
||||
f"DELETE FROM {self.fts_table} WHERE id = ?;",
|
||||
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
|
||||
[(row[0],) for row in fts_data],
|
||||
)
|
||||
|
||||
# INSERT new entries
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
|
||||
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
|
||||
fts_data,
|
||||
)
|
||||
|
||||
|
|
@ -270,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
emb_blob = serialize_vector(emb_list)
|
||||
query_sql = f"""
|
||||
SELECT m.id, m.chunk, v.distance
|
||||
FROM {self.vector_table} AS v
|
||||
JOIN {self.metadata_table} AS m ON m.id = v.id
|
||||
FROM [{self.vector_table}] AS v
|
||||
JOIN [{self.metadata_table}] AS m ON m.id = v.id
|
||||
WHERE v.embedding MATCH ? AND k = ?
|
||||
ORDER BY v.distance;
|
||||
"""
|
||||
|
|
@ -312,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
cur = connection.cursor()
|
||||
try:
|
||||
query_sql = f"""
|
||||
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
|
||||
FROM {self.fts_table} AS f
|
||||
JOIN {self.metadata_table} AS m ON m.id = f.id
|
||||
SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
|
||||
FROM [{self.fts_table}] AS f
|
||||
JOIN [{self.metadata_table}] AS m ON m.id = f.id
|
||||
WHERE f.content MATCH ?
|
||||
ORDER BY score ASC
|
||||
LIMIT ?;
|
||||
|
|
@ -425,27 +440,81 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
self.files_api = files_api
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
def _setup_connection():
|
||||
# Open a connection to the SQLite database (the file is specified in the config).
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for db_json in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(db_json)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# nothing to do since we don't maintain a persistent connection
|
||||
pass
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
|
||||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=SQLiteVecIndex(
|
||||
dimension=vector_db.embedding_dimension,
|
||||
db_path=self.config.db_path,
|
||||
bank_id=vector_db.identifier,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to SQLite database."""
|
||||
|
||||
def _create_or_store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create a table to persist vector DB registrations.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||
id TEXT PRIMARY KEY,
|
||||
metadata TEXT
|
||||
);
|
||||
""")
|
||||
# Create a table to persist OpenAI vector stores.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_stores (
|
||||
id TEXT PRIMARY KEY,
|
||||
metadata TEXT
|
||||
);
|
||||
""")
|
||||
# Create a table to persist OpenAI vector store files.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
|
||||
|
|
@ -464,168 +533,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
);
|
||||
""")
|
||||
connection.commit()
|
||||
# Load any existing vector DB registrations.
|
||||
cur.execute("SELECT metadata FROM vector_dbs")
|
||||
vector_db_rows = cur.fetchall()
|
||||
return vector_db_rows
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
vector_db_rows = await asyncio.to_thread(_setup_connection)
|
||||
|
||||
# Load existing vector DBs
|
||||
for row in vector_db_rows:
|
||||
vector_db_data = row[0]
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
# Load existing OpenAI vector stores using the mixin method
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
# nothing to do since we don't maintain a persistent connection
|
||||
pass
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
def _register_db():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||
(vector_db.identifier, vector_db.model_dump_json()),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_register_db)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
def _delete_vector_db_from_registry():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_vector_db_from_registry)
|
||||
|
||||
# OpenAI Vector Store Mixin abstract method implementations
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Save vector store metadata to SQLite database."""
|
||||
|
||||
def _store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
|
||||
(store_id, json.dumps(store_info)),
|
||||
)
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store {store_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("SELECT metadata FROM openai_vector_stores")
|
||||
rows = cur.fetchall()
|
||||
return rows
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
rows = await asyncio.to_thread(_load)
|
||||
stores = {}
|
||||
for row in rows:
|
||||
store_data = row[0]
|
||||
store_info = json.loads(store_data)
|
||||
stores[store_info["id"]] = store_info
|
||||
return stores
|
||||
|
||||
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Update vector store metadata in SQLite database."""
|
||||
|
||||
def _update():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
|
||||
(json.dumps(store_info), store_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_update)
|
||||
|
||||
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||
"""Delete vector store metadata from SQLite database."""
|
||||
|
||||
def _delete():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to SQLite database."""
|
||||
|
||||
def _store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_info)),
|
||||
|
|
@ -643,7 +550,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
connection.close()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_store)
|
||||
await asyncio.to_thread(_create_or_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
|
|
@ -722,6 +629,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
|
|
@ -730,15 +641,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await asyncio.to_thread(_delete)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
# and then call our index's add_chunks.
|
||||
await self.cache[vector_db_id].insert_chunks(chunks)
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
if vector_db_id not in self.cache:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await self.cache[vector_db_id].query_chunks(query, params)
|
||||
return await index.query_chunks(query, params)
|
||||
|
|
|
|||
|
|
@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
|
|||
"anthropic/claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
|
|
@ -22,4 +26,4 @@ MODEL_ENTRIES = [
|
|||
"meta.llama3-1-405b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://inference-docs.cerebras.ai/models
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
|
|
@ -18,4 +21,8 @@ MODEL_ENTRIES = [
|
|||
"llama-3.3-70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
]
|
||||
build_hf_repo_model_entry(
|
||||
"llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
model_entries = [
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
|
|
@ -56,7 +59,7 @@ model_entries = [
|
|||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
|
|
@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||||
|
|
@ -40,14 +51,6 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
|
|
@ -64,4 +67,4 @@ MODEL_ENTRIES = [
|
|||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
|
|||
"gemini/gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
self._openai_client = None
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self.config.url}/openai/v1",
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
return AsyncOpenAI(
|
||||
base_url=f"{self.config.url}/openai/v1",
|
||||
api_key=self.get_api_key(),
|
||||
)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
|
|
@ -51,4 +53,4 @@ MODEL_ENTRIES = [
|
|||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama3-8b-instruct",
|
||||
|
|
@ -99,4 +102,4 @@ MODEL_ENTRIES = [
|
|||
),
|
||||
# TODO(mf): how do we handle Nemotron models?
|
||||
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
|
|
@ -93,41 +92,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
self._config = config
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
|
||||
@property
|
||||
def _client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
|
||||
some models are hosted on different URLs. This function returns the appropriate client
|
||||
for the given provider_model_id.
|
||||
Returns an OpenAI client for the configured NVIDIA API endpoint.
|
||||
|
||||
This relies on lru_cache and self._default_client to avoid creating a new client for each request
|
||||
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
|
||||
|
||||
:param provider_model_id: The provider model ID
|
||||
:return: An OpenAI client
|
||||
"""
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
|
||||
"""
|
||||
Maintain a single OpenAI client per base_url.
|
||||
"""
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
special_model_urls = {
|
||||
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
|
||||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
return _get_client_for_base_url(base_url)
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
if not self.model_store:
|
||||
@@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).completions.create(**request)
|
||||
response = await self._client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
@@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).embeddings.create(
|
||||
response = await self._client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
@@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).chat.completions.create(**request)
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
@@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).completions.create(**params)
|
||||
return await self._client.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
@@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||
return await self._client.chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
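The hunks above drop the per-model `_get_client(provider_model_id)` helper, which kept one lru_cache'd AsyncOpenAI per base URL, in favor of a single `_client` property on the adapter. For reference, the per-URL caching idiom on its own looks roughly like this standalone sketch (not the adapter code itself):

from functools import lru_cache

from openai import AsyncOpenAI


@lru_cache
def client_for_base_url(base_url: str, api_key: str, timeout: int = 60) -> AsyncOpenAI:
    # one AsyncOpenAI instance per distinct (base_url, api_key, timeout) tuple
    return AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout)

Dropping the cache trades a little per-call construction cost for not having to invalidate cached clients when the endpoint or credentials change.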
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
# The Llama Guard models don't have their full fp16 versions
|
||||
# so we are going to alias their default version to the canonical SKU
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
]
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:8b-instruct-fp16",
|
||||
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
"llama3.3:70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# The Llama Guard models don't have their full fp16 versions
|
||||
# so we are going to alias their default version to the canonical SKU
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="all-minilm:l6-v2",
|
||||
aliases=["all-minilm"],
|
||||
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
@@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id=model_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": model_info.embedding_dimension,
|
||||
"context_length": model_info.context_length,
|
||||
},
|
||||
)
|
||||
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
||||
|
|
|
|||
@@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# if we do not set this, users will be exposed to the
|
||||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
self._openai_client = AsyncOpenAI(
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
@@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
@@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self._openai_client.completions.create(**params)
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
@@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._openai_client.chat.completions.create(**params)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
@@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._openai_client.embeddings.create(**params)
|
||||
response = await self._get_openai_client().embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
|
|
|
|||
@@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
@@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
|
||||
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
|
||||
RUNPOD_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
@@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
}
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(provider_model_id, model_descriptor)
|
||||
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
||||
|
||||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
|
|
|
|||
@@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
@@ -46,8 +54,4 @@ MODEL_ENTRIES = [
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
@@ -7,6 +7,7 @@
import json
|
||||
from collections.abc import Iterable
|
||||
|
||||
import requests
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
)
|
||||
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
@@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
self.environment_available_models = []
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=self.config.api_key,
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
)
|
||||
|
||||
@@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
list_models_url = self.config.url + "/models"
|
||||
if len(self.environment_available_models) == 0:
|
||||
try:
|
||||
response = requests.get(list_models_url)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||
|
||||
if model_id.split("sambanova/")[-1] not in self.environment_available_models:
|
||||
logger.warning(f"Model {model_id} not available in {list_models_url}")
|
||||
return model
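register_model above fetches the provider's /models listing once, caches it on the adapter, and only warns when a model id is missing instead of failing registration. A compressed sketch of that availability check, assuming an endpoint that returns {"data": [{"id": ...}, ...]} (the base URL below is illustrative, not from this diff):

import requests


def fetch_available_models(base_url: str) -> list[str]:
    # one-off listing of the model ids the endpoint currently serves
    response = requests.get(f"{base_url}/models", timeout=30)
    response.raise_for_status()
    return [entry.get("id") for entry in response.json().get("data", [])]


available = fetch_available_models("https://api.example-provider.com/v1")  # hypothetical URL
if "Meta-Llama-3.1-8B-Instruct" not in available:
    print("model not advertised by the endpoint; registering it anyway")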
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
|
|
|
|||
@@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
]
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
@@ -40,14 +50,6 @@ MODEL_ENTRIES = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
|
||||
model_type=ModelType.embedding,
|
||||
@@ -78,4 +80,4 @@ MODEL_ENTRIES = [
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
],
|
||||
),
|
||||
]
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
|||
@@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client:
|
||||
# Together client has no close method, so just set to None
|
||||
self._client = None
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
@@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
if not self._client:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
self._client = AsyncTogether(api_key=together_api_key)
|
||||
return self._client
|
||||
together_api_key = None
|
||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return AsyncTogether(api_key=together_api_key)
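The refactored _get_client resolves the Together API key per call: a key in the static config wins, otherwise it must arrive through the X-LlamaStack-Provider-Data header. A small sketch of that precedence on its own, treating the provider-data object as a plain attribute bag:

def resolve_together_api_key(config_api_key: str | None, provider_data) -> str:
    # static config takes precedence; otherwise fall back to per-request provider data
    if config_api_key:
        return config_api_key
    if provider_data is None or not getattr(provider_data, "together_api_key", None):
        raise ValueError(
            'Pass Together API Key in the header X-LlamaStack-Provider-Data as '
            '{ "together_api_key": <your api key>}'
        )
    return provider_data.together_api_key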
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
together_client = self._get_client().client
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=together_client.base_url,
|
||||
api_key=together_client.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
together_client = self._get_client().client
|
||||
return AsyncOpenAI(
|
||||
base_url=together_client.base_url,
|
||||
api_key=together_client.api_key,
|
||||
)
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
|
|
|||
@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
|
||||
def __init__(self, config: SambaNovaSafetyConfig) -> None:
|
||||
self.config = config
|
||||
self.environment_available_models = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
list_models_url = self.config.url + "/models"
|
||||
try:
|
||||
response = requests.get(list_models_url)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||
available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||
if len(self.environment_available_models) == 0:
|
||||
try:
|
||||
response = requests.get(list_models_url)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Request to {list_models_url} failed") from e
|
||||
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
|
||||
if (
|
||||
len(available_models) == 0
|
||||
or "guard" not in shield.provider_resource_id.lower()
|
||||
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
|
||||
"guard" not in shield.provider_resource_id.lower()
|
||||
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
||||
):
|
||||
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
|
||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
|
||||
|
|
|
|||
@@ -217,7 +217,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
|
||||
|
||||
|
|
|
|||
@@ -8,7 +8,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
|
||||
token: str | None = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||
@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
||||
return {
|
||||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_remote_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
@@ -12,7 +12,7 @@ import re
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, MilvusClient
|
||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
@@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def initialize(self):
|
||||
# MilvusIndex does not require explicit initialization
|
||||
# TODO: could move collection creation into initialization but it is not really necessary
|
||||
pass
|
||||
|
||||
async def delete(self):
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
||||
@@ -69,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=["content"],
|
||||
output_field_names=["sparse"],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
|
||||
@@ -83,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"content": chunk.content,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
# sparse field will be handled by BM25 function automatically
|
||||
}
|
||||
)
|
||||
try:
|
||||
@@ -102,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
anns_field="vector",
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
@@ -116,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
try:
|
||||
# Use Milvus's built-in BM25 search
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[query_string], # Raw text query
|
||||
anns_field="sparse", # Use sparse field for BM25
|
||||
output_fields=["chunk_content"], # Output the chunk content
|
||||
limit=k,
|
||||
search_params={
|
||||
"params": {
|
||||
"drop_ratio_search": 0.2, # Ignore low-importance terms
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for res in search_res[0]:
|
||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||
chunks.append(chunk)
|
||||
scores.append(res["distance"]) # BM25 score from Milvus
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing BM25 search: {e}")
|
||||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
# Simple text search using content field
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name=self.collection_name,
|
||||
filter='content like "%{content}%"',
|
||||
filter_params={"content": query_string},
|
||||
output_fields=["*"],
|
||||
limit=k,
|
||||
)
|
||||
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
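With the BM25 path and its fallback in place, keyword retrieval goes through the same query_chunks entry point as vector search; only the mode in params changes. A hedged usage sketch (the adapter instance, the vector DB id, and the printed fields are assumptions, and the call must run inside an async context):

# `vector_io` is a MilvusVectorIOAdapter; "my-docs" is a previously registered vector DB id
response = await vector_io.query_chunks(
    vector_db_id="my-docs",
    query="how do I configure the kv store?",
    params={"mode": "keyword"},  # default behaviour is vector similarity search
)
for chunk, score in zip(response.chunks, response.scores):
    print(score, chunk.chunk_id)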
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
@@ -174,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
@@ -199,6 +319,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
@@ -238,38 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
if params and params.get("mode") == "keyword":
|
||||
# Check if this is inline Milvus (Milvus-Lite)
|
||||
if hasattr(self.config, "db_path"):
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported in Milvus-Lite. "
|
||||
"Please use a remote Milvus server for keyword search functionality."
|
||||
)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Save vector store metadata to persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
self.openai_vector_stores[store_id] = store_info
|
||||
|
||||
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Update vector store metadata in persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
self.openai_vector_stores[store_id] = store_info
|
||||
|
||||
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||
"""Delete vector store metadata from persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.delete(key)
|
||||
if store_id in self.openai_vector_stores:
|
||||
del self.openai_vector_stores[store_id]
|
||||
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
start_key = OPENAI_VECTOR_STORES_PREFIX
|
||||
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
|
||||
stored = await self.kvstore.values_in_range(start_key, end_key)
|
||||
return {json.loads(s)["id"]: json.loads(s) for s in stored}
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
|
|
|
|||
@@ -8,6 +8,10 @@ from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
|
||||
user: str | None = Field(default="postgres")
|
||||
password: str | None = Field(default="mysecretpassword")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
__distro_dir__: str,
|
||||
host: str = "${env.PGVECTOR_HOST:=localhost}",
|
||||
port: int = "${env.PGVECTOR_PORT:=5432}",
|
||||
db: str = "${env.PGVECTOR_DB}",
|
||||
@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
password: str = "${env.PGVECTOR_PASSWORD}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {"host": host, "port": port, "db": db, "user": user, "password": password}
|
||||
return {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"db": db,
|
||||
"user": user,
|
||||
"password": password,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="pgvector_registry.db",
|
||||
),
|
||||
}
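For reference, once the env placeholders are substituted, the dict above lands in the generated run config roughly as follows; the inner kvstore keys come from SqliteKVStoreConfig, so treat them as illustrative rather than exact:

{
    "host": "localhost",
    "port": 5432,
    "db": "postgres",
    "user": "postgres",
    "password": "mysecretpassword",
    "kvstore": {
        "type": "sqlite",
        "db_path": "<distro_dir>/pgvector_registry.db",  # field name assumed, not taken from this diff
    },
}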
|
||||
|
|
|
|||
@@ -13,24 +13,18 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
SearchRankingOptions,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreListFilesResponse,
|
||||
VectorStoreListResponse,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
|
||||
|
||||
|
||||
def check_extension_version(cur):
|
||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||
@@ -69,7 +70,7 @@ def load_models(cur, cls):
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
self.kvstore = kvstore
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: PGVectorVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
files_api: Files | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_store: dict[str, dict[str, Any]] = {}
|
||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
@@ -201,14 +216,28 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
log.info("Connection to PGVector database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
assert self.kvstore is not None
|
||||
# Upsert model metadata in Postgres
|
||||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
# Remove provider index and cache
|
||||
if vector_db_id in self.cache:
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
# Delete vector DB metadata from KV store
|
||||
assert self.kvstore is not None
|
||||
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
@@ -237,107 +266,124 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
# OpenAI Vector Stores File operations are not supported in PGVector
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to Postgres database."""
|
||||
if self.conn is None:
|
||||
raise RuntimeError("PostgreSQL connection is not initialized")
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
metadata JSONB,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
contents JSONB,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Insert file metadata
|
||||
files_query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
|
||||
VALUES %s
|
||||
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
|
||||
"""
|
||||
)
|
||||
files_values = [(store_id, file_id, Json(file_info))]
|
||||
execute_values(cur, files_query, files_values, template="(%s, %s, %s)")
|
||||
# Insert file contents
|
||||
contents_query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO openai_vector_store_files_contents (store_id, file_id, contents)
|
||||
VALUES %s
|
||||
ON CONFLICT (store_id, file_id) DO UPDATE SET contents = EXCLUDED.contents
|
||||
"""
|
||||
)
|
||||
contents_values = [(store_id, file_id, Json(file_contents))]
|
||||
execute_values(cur, contents_query, contents_values, template="(%s, %s, %s)")
|
||||
except Exception as e:
|
||||
log.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from Postgres database."""
|
||||
if self.conn is None:
|
||||
raise RuntimeError("PostgreSQL connection is not initialized")
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
"SELECT metadata FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return row[0] if row and row[0] is not None else {}
|
||||
except Exception as e:
|
||||
log.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
|
||||
return {}
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from Postgres database."""
|
||||
if self.conn is None:
|
||||
raise RuntimeError("PostgreSQL connection is not initialized")
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return row[0] if row and row[0] is not None else []
|
||||
except Exception as e:
|
||||
log.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
|
||||
return []
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in Postgres database."""
|
||||
if self.conn is None:
|
||||
raise RuntimeError("PostgreSQL connection is not initialized")
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
query = sql.SQL(
|
||||
"""
|
||||
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
|
||||
VALUES %s
|
||||
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
|
||||
"""
|
||||
)
|
||||
values = [(store_id, file_id, Json(file_info))]
|
||||
execute_values(cur, query, values, template="(%s, %s, %s)")
|
||||
except Exception as e:
|
||||
log.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> VectorStoreListFilesResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from Postgres database."""
|
||||
if self.conn is None:
|
||||
raise RuntimeError("PostgreSQL connection is not initialized")
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
|
||||
(store_id, file_id),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
|
||||
(store_id, file_id),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
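All of the PGVector file-metadata helpers above lean on the same Postgres upsert idiom: execute_values with an ON CONFLICT clause, so repeated saves for the same (store_id, file_id) simply overwrite the JSONB payload. A distilled sketch of that idiom (the connection and table are placeholders):

from psycopg2.extras import Json, execute_values


def upsert_file_metadata(conn, store_id: str, file_id: str, file_info: dict) -> None:
    # idempotent write: a second save for the same key replaces the stored metadata
    with conn.cursor() as cur:
        execute_values(
            cur,
            """
            INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
            VALUES %s
            ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
            """,
            [(store_id, file_id, Json(file_info))],
            template="(%s, %s, %s)",
        )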
|
||||
|
|
|
|||
@@ -214,7 +214,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
|
||||
|
|
|
|||
@@ -6,15 +6,26 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
|
||||
|
||||
class WeaviateVectorIOConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="weaviate_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
|
||||
self.client = client
|
||||
self.collection_name = collection_name
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
NeedsRequestProviderData,
|
||||
VectorDBsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: WeaviateVectorIOConfig,
|
||||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
provider_data = self.get_request_provider_data()
|
||||
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
"""Set up KV store and load existing vector DBs and OpenAI vector stores."""
|
||||
# Initialize KV store for metadata
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
# Load existing vector DB definitions
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored = await self.kvstore.values_in_range(start_key, end_key)
|
||||
for raw in stored:
|
||||
vector_db = VectorDB.model_validate_json(raw)
|
||||
client = self._get_client()
|
||||
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=idx,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
# Load OpenAI vector stores metadata into cache
|
||||
await self.initialize_openai_vector_stores()
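initialize() above rehydrates state from the KV store before any request arrives: every persisted VectorDB record is re-wrapped in a WeaviateIndex and put back into the cache. A reduced sketch of that reload loop, assuming kvstore.values_in_range returns the JSON blobs written at registration time:

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.memory.vector_store import VectorDBWithIndex

VECTOR_DBS_PREFIX = "vector_dbs:weaviate:v3::"  # prefix mirrors the constant above


async def reload_vector_dbs(kvstore, make_index, cache, inference_api) -> None:
    stored = await kvstore.values_in_range(VECTOR_DBS_PREFIX, f"{VECTOR_DBS_PREFIX}\xff")
    for raw in stored:
        vector_db = VectorDB.model_validate_json(raw)  # pydantic round-trip from the stored JSON
        cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db=vector_db,
            index=make_index(vector_db.identifier),  # e.g. a WeaviateIndex factory
            inference_api=inference_api,
        )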
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
# OpenAI Vector Stores File operations are not supported in Weaviate
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
|
||||
|
|
|
|||
@@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
]
|
||||
if additional_aliases:
|
||||
aliases.extend(additional_aliases)
|
||||
aliases = [alias for alias in aliases if alias is not None]
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=aliases,
|
||||
@@ -82,15 +83,43 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def get_llama_model(self, provider_model_id: str) -> str | None:
|
||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider (non-static check).
|
||||
|
||||
This is for subclassing purposes, so providers can check if a specific
|
||||
model is currently available for use through dynamic means (e.g., API calls).
|
||||
|
||||
This method should NOT check statically configured model entries in
|
||||
`self.alias_to_provider_id_map` - that is handled separately in register_model.
|
||||
|
||||
Default implementation returns False (no dynamic models available).
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
|
||||
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
|
||||
# Check if model is supported in static configuration
|
||||
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
# If not found in static config, check if it's available dynamically from provider
|
||||
if not supported_model_id:
|
||||
if await self.check_model_availability(model.provider_resource_id):
|
||||
supported_model_id = model.provider_resource_id
|
||||
else:
|
||||
# note: we cannot provide a complete list of supported models without
|
||||
# getting a complete list from the provider, so we return "..."
|
||||
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
|
||||
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
|
||||
|
||||
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
if provider_resource_id:
|
||||
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
|
||||
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
|
||||
raise ValueError(
|
||||
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||
)
|
||||
@@ -113,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
# Register the model alias, ensuring it maps to the correct provider model id
|
||||
self.alias_to_provider_id_map[model.model_id] = supported_model_id
|
||||
|
||||
return model
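The reworked register_model gives adapters an escape hatch: when the static alias map misses, check_model_availability is consulted, so a provider that can list its models at runtime only needs to override that one hook. A hedged sketch of such an override (the HTTP client choice and endpoint shape are assumptions, not part of this diff):

import httpx


class ExampleAdapter(ModelRegistryHelper):
    def __init__(self, base_url: str, model_entries):
        super().__init__(model_entries)
        self._base_url = base_url

    async def check_model_availability(self, model: str) -> bool:
        # dynamic check: ask the provider which models it is serving right now
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{self._base_url}/models")
        resp.raise_for_status()
        return model in {m["id"] for m in resp.json().get("data", [])}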
@@ -5,6 +5,7 @@
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
@@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    # KV store for persisting OpenAI vector store metadata
+    kvstore: KVStore | None

-    @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            stores[info["id"]] = info
+        return stores

-    @abstractmethod
     async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Update vector store metadata in persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
         """Delete vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_vector_stores.pop(store_id, None)

     @abstractmethod
     async def _save_openai_vector_store_file(
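With the persistence methods now implemented on the mixin itself, an implementing provider only has to supply a KVStore and warm the cache. A rough sketch of what that wiring could look like (the adapter name and config fields are illustrative, and kvstore_impl is assumed to be the stack's usual KVStore factory):

class MyVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO):
    def __init__(self, config, files_api) -> None:
        self.config = config
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        # build the KV store from the provider config, then load persisted vector stores
        self.kvstore = await kvstore_impl(self.config.kvstore)  # assumed factory
        await self.initialize_openai_vector_stores()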

@@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
         """Unregister a vector database (provider-specific implementation)."""
         pass

+    async def initialize_openai_vector_stores(self) -> None:
+        """Load existing OpenAI vector stores into the in-memory cache."""
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
     @abstractmethod
     async def insert_chunks(
         self,

@@ -147,8 +172,9 @@ class OpenAIVectorStoreMixin(ABC):
         provider_vector_db_id: str | None = None,
     ) -> VectorStoreObject:
         """Creates a vector store."""
-        store_id = name or str(uuid.uuid4())
         created_at = int(time.time())
+        # Derive the canonical vector_db_id (allow override, else generate)
+        vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"

         if provider_id is None:
             raise ValueError("Provider ID is required")

@@ -156,19 +182,19 @@ class OpenAIVectorStoreMixin(ABC):
         if embedding_model is None:
             raise ValueError("Embedding model is required")

-        # Use provided embedding dimension or default to 384
+        # Embedding dimension is required (defaulted to 384 if not provided)
         if embedding_dimension is None:
             raise ValueError("Embedding dimension is required")

-        provider_vector_db_id = provider_vector_db_id or store_id
+        # Register the VectorDB backing this vector store
         vector_db = VectorDB(
-            identifier=store_id,
+            identifier=vector_db_id,
             embedding_dimension=embedding_dimension,
             embedding_model=embedding_model,
             provider_id=provider_id,
-            provider_resource_id=provider_vector_db_id,
+            provider_resource_id=vector_db_id,
+            vector_db_name=name,
         )
-        # Register the vector DB
         await self.register_vector_db(vector_db)

         # Create OpenAI vector store metadata

@@ -182,11 +208,11 @@ class OpenAIVectorStoreMixin(ABC):
             in_progress=0,
             total=0,
         )
-        store_info = {
-            "id": store_id,
+        store_info: dict[str, Any] = {
+            "id": vector_db_id,
             "object": "vector_store",
             "created_at": created_at,
-            "name": store_id,
+            "name": name,
             "usage_bytes": 0,
             "file_counts": file_counts.model_dump(),
             "status": status,

@@ -206,18 +232,18 @@ class OpenAIVectorStoreMixin(ABC):
             store_info["metadata"] = metadata

         # Save to persistent storage (provider-specific)
-        await self._save_openai_vector_store(store_id, store_info)
+        await self._save_openai_vector_store(vector_db_id, store_info)

         # Store in memory cache
-        self.openai_vector_stores[store_id] = store_info
+        self.openai_vector_stores[vector_db_id] = store_info

         # Now that our vector store is created, attach any files that were provided
         file_ids = file_ids or []
-        tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
+        tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
         await asyncio.gather(*tasks)

         # Get the updated store info and return it
-        store_info = self.openai_vector_stores[store_id]
+        store_info = self.openai_vector_stores[vector_db_id]
         return VectorStoreObject.model_validate(store_info)

     async def openai_list_vector_stores(
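A client-side usage sketch of the creation flow above, assuming the OpenAI-compatible vector-stores surface exposed by the stack; the exact client call shape (in particular passing embedding settings via extra_body) is an assumption, not taken from the diff:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
vs = client.vector_stores.create(
    name="my-docs",
    extra_body={"embedding_model": "all-MiniLM-L6-v2", "embedding_dimension": 384},
)
print(vs.id)  # a generated "vs_..." identifier unless a provider_vector_db_id was supplied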
@@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [


 class SqlRecord(ProtectedResource):
     """Simple ProtectedResource implementation for SQL records."""

-    def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
+    def __init__(self, record_id: str, table_name: str, owner: User):
         self.type = f"sql_record::{table_name}"
         self.identifier = record_id
-
-        if access_attributes:
-            self.owner = User(
-                principal="system",
-                attributes=access_attributes,
-            )
-        else:
-            self.owner = User(
-                principal="system_public",
-                attributes=None,
-            )
+        self.owner = owner


 class AuthorizedSqlStore:

@@ -101,22 +89,27 @@ class AuthorizedSqlStore:

     async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
         """Create a table with built-in access control support."""
-        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
-
         enhanced_schema = dict(schema)
         if "access_attributes" not in enhanced_schema:
             enhanced_schema["access_attributes"] = ColumnType.JSON
+        if "owner_principal" not in enhanced_schema:
+            enhanced_schema["owner_principal"] = ColumnType.STRING

         await self.sql_store.create_table(table, enhanced_schema)
+        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
+        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
         """Insert a row with automatic access control attribute capture."""
         enhanced_data = dict(data)

         current_user = get_authenticated_user()
-        if current_user and current_user.attributes:
+        if current_user:
+            enhanced_data["owner_principal"] = current_user.principal
             enhanced_data["access_attributes"] = current_user.attributes
         else:
+            enhanced_data["owner_principal"] = None
             enhanced_data["access_attributes"] = None

         await self.sql_store.insert(table, enhanced_data)

@@ -146,9 +139,12 @@ class AuthorizedSqlStore:

         for row in rows.data:
             stored_access_attrs = row.get("access_attributes")
+            stored_owner_principal = row.get("owner_principal") or ""

             record_id = row.get("id", "unknown")
-            sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
+            sql_record = SqlRecord(
+                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
+            )

             if is_action_allowed(policy, Action.READ, sql_record, current_user):
                 filtered_rows.append(row)

@@ -186,8 +182,10 @@ class AuthorizedSqlStore:
         Only applies SQL filtering for the default policy to ensure correctness.
         For custom policies, uses conservative filtering to avoid blocking legitimate access.
         """
+        current_user = get_authenticated_user()
+
         if not policy or policy == SQL_OPTIMIZED_POLICY:
-            return self._build_default_policy_where_clause()
+            return self._build_default_policy_where_clause(current_user)
         else:
             return self._build_conservative_where_clause()

@@ -227,29 +225,27 @@ class AuthorizedSqlStore:

     def _get_public_access_conditions(self) -> list[str]:
         """Get the SQL conditions for public access."""
+        # Public records are records that have no owner_principal or access_attributes
+        conditions = ["owner_principal = ''"]
         if self.database_type == SqlStoreType.postgres:
             # Postgres stores JSON null as 'null'
-            return ["access_attributes::text = 'null'"]
+            conditions.append("access_attributes::text = 'null'")
         elif self.database_type == SqlStoreType.sqlite:
-            return ["access_attributes = 'null'"]
+            conditions.append("access_attributes = 'null'")
         else:
             raise ValueError(f"Unsupported database type: {self.database_type}")
+        return conditions

-    def _build_default_policy_where_clause(self) -> str:
+    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
         """Build SQL WHERE clause for the default policy.

         Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
         This means user must match ALL attribute categories that exist in the resource.
         """
-        current_user = get_authenticated_user()
-
         base_conditions = self._get_public_access_conditions()
-        if not current_user or not current_user.attributes:
-            # Only allow public records
-            return f"({' OR '.join(base_conditions)})"
-        else:
-            user_attr_conditions = []
+        user_attr_conditions = []
+
+        if current_user and current_user.attributes:
             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []

@@ -269,7 +265,7 @@ class AuthorizedSqlStore:
             all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
             base_conditions.append(all_requirements_met)

-            return f"({' OR '.join(base_conditions)})"
+        return f"({' OR '.join(base_conditions)})"

     def _build_conservative_where_clause(self) -> str:
         """Conservative SQL filtering for custom policies.
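To make the generated filter concrete, here is a rough sketch of the pieces the default-policy builder combines on SQLite; the user attributes and the abbreviated attribute-matching conditions are assumed examples, not the exact JSON-extraction SQL the store emits:

# public rows: no owner principal and JSON-null access attributes
base_conditions = ["owner_principal = ''", "access_attributes = 'null'"]
# a user with {"roles": ["admin"], "teams": ["ml"]} additionally unlocks rows whose stored
# attribute categories are all matched, so the final clause has roughly this shape:
# (owner_principal = '' OR access_attributes = 'null'
#  OR ((<roles contains 'admin'>) AND (<teams contains 'ml'>)))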
@@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         engine = create_async_engine(self.config.engine_str)

         try:
-            inspector = inspect(engine)
-
-            table_names = inspector.get_table_names()
-            if table not in table_names:
-                return
-
-            existing_columns = inspector.get_columns(table)
-            column_names = [col["name"] for col in existing_columns]
-
-            if column_name in column_names:
-                return
-
-            sqlalchemy_type = TYPE_MAPPING.get(column_type)
-            if not sqlalchemy_type:
-                raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
-
-            # Create the ALTER TABLE statement
-            # Note: We need to get the dialect-specific type name
-            dialect = engine.dialect
-            type_impl = sqlalchemy_type()
-            compiled_type = type_impl.compile(dialect=dialect)
-
-            nullable_clause = "" if nullable else " NOT NULL"
-            add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
-
             async with engine.begin() as conn:
+
+                def check_column_exists(sync_conn):
+                    inspector = inspect(sync_conn)
+
+                    table_names = inspector.get_table_names()
+                    if table not in table_names:
+                        return False, False  # table doesn't exist, column doesn't exist
+
+                    existing_columns = inspector.get_columns(table)
+                    column_names = [col["name"] for col in existing_columns]
+
+                    return True, column_name in column_names  # table exists, column exists or not
+
+                table_exists, column_exists = await conn.run_sync(check_column_exists)
+                if not table_exists or column_exists:
+                    return
+
+                sqlalchemy_type = TYPE_MAPPING.get(column_type)
+                if not sqlalchemy_type:
+                    raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
+
+                # Create the ALTER TABLE statement
+                # Note: We need to get the dialect-specific type name
+                dialect = engine.dialect
+                type_impl = sqlalchemy_type()
+                compiled_type = type_impl.compile(dialect=dialect)
+
+                nullable_clause = "" if nullable else " NOT NULL"
+                add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
+
                 await conn.execute(add_column_sql)

-        except Exception:
+        except Exception as e:
             # If any error occurs during migration, log it but don't fail
             # The table creation will handle adding the column
+            logger.error(f"Error adding column {column_name} to table {table}: {e}")
             pass

@@ -9,14 +9,12 @@ import inspect
 import json
 from collections.abc import AsyncGenerator, Callable
 from functools import wraps
-from typing import Any, TypeVar
+from typing import Any

 from pydantic import BaseModel

 from llama_stack.models.llama.datatypes import Primitive

-T = TypeVar("T")
-

 def serialize_value(value: Any) -> Primitive:
     return str(_prepare_for_json(value))

@@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
     return str(value)


-def trace_protocol(cls: type[T]) -> type[T]:
+def trace_protocol[T](cls: type[T]) -> type[T]:
    """
    A class decorator that automatically traces all methods in a protocol/base class
    and its inheriting classes.
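The signature change above drops the module-level TypeVar in favour of Python 3.12's inline type-parameter syntax (PEP 695). As a minimal illustration of the equivalence, using a made-up decorator name rather than anything from this repo:

# before: T = TypeVar("T"); def my_decorator(cls: type[T]) -> type[T]: ...
# after (PEP 695), no TypeVar needed:
def my_decorator[T](cls: type[T]) -> type[T]:
    return cls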

@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",

@@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",

@@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",

@@ -39,6 +39,9 @@ providers:
     provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb
     config:

@@ -51,6 +54,9 @@ providers:
       db: ${env.PGVECTOR_DB:=}
       user: ${env.PGVECTOR_USER:=}
       password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

@@ -144,6 +144,9 @@ providers:
     provider_type: inline::sqlite-vec
     config:
       db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
   - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
     provider_type: inline::milvus
     config:

@@ -163,6 +166,9 @@ providers:
      db: ${env.PGVECTOR_DB:=}
      user: ${env.PGVECTOR_USER:=}
      password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
@ -256,11 +262,46 @@ inference_store:
|
|||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama3.1-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama3.1-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama-3.3-70b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama-3.3-70b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
provider_model_id: llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
|
||||
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
|
||||
provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
|
||||
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
|
||||
provider_model_id: ${env.SAFETY_MODEL:=__disabled__}
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384}
|
||||
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}
|
||||
|
|
@ -342,26 +383,6 @@ models:
|
|||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
|
|
@ -389,6 +410,26 @@ models:
|
|||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
model_type: embedding
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
|
||||
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
|
|
@ -459,26 +500,6 @@ models:
|
|||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 768
|
||||
context_length: 8192
|
||||
|
|
@ -523,6 +544,264 @@ models:
|
|||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-8B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-8B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
|
||||
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
|
||||
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-8b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-8b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-70b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-70b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-405b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
|
||||
provider_model_id: meta.llama3-1-405b-instruct-v1:0
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct
|
||||
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_model_id: databricks-meta-llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_model_id: databricks-meta-llama-3-1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct
|
||||
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_model_id: databricks-meta-llama-3-1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
|
||||
provider_model_id: databricks-meta-llama-3-1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 4096
|
||||
context_length: 512
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l
|
||||
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
|
||||
provider_model_id: snowflake/arctic-embed-l
|
||||
model_type: embedding
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-8B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-70B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B:bf16-mp8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B:bf16-mp16
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-8B-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-70B-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B-Instruct:bf16-mp8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.1-405B-Instruct:bf16-mp16
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.2-1B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B
|
||||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.2-3B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
|
|
@@ -894,7 +1173,9 @@ models:
   model_id: all-MiniLM-L6-v2
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
-shields: []
+shields:
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
@@ -31,6 +31,15 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.bedrock.models import (
+    MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.cerebras.models import (
+    MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
+)
+from llama_stack.providers.remote.inference.databricks.databricks import (
+    MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
+)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )

@@ -40,9 +49,15 @@ from llama_stack.providers.remote.inference.gemini.models import (
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.nvidia.models import (
+    MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
+)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.runpod.runpod import (
+    MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
+)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )

@@ -59,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )

@@ -72,6 +88,11 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
         "gemini": GEMINI_MODEL_ENTRIES,
         "groq": GROQ_MODEL_ENTRIES,
         "sambanova": SAMBANOVA_MODEL_ENTRIES,
+        "cerebras": CEREBRAS_MODEL_ENTRIES,
+        "bedrock": BEDROCK_MODEL_ENTRIES,
+        "databricks": DATABRICKS_MODEL_ENTRIES,
+        "nvidia": NVIDIA_MODEL_ENTRIES,
+        "runpod": RUNPOD_MODEL_ENTRIES,
     }

     # Special handling for providers with dynamic model entries

@@ -81,6 +102,10 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
                 provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
+            ProviderModelEntry(
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
+                model_type=ModelType.llm,
+            ),
             ProviderModelEntry(
                 provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
                 model_type=ModelType.embedding,
@@ -100,6 +125,20 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
     return model_entries_map.get(provider_type, [])


+def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
+    """Get model entries for a specific provider type."""
+    safety_model_entries_map = {
+        "ollama": [
+            ProviderModelEntry(
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
+                model_type=ModelType.llm,
+            ),
+        ],
+    }
+
+    return safety_model_entries_map.get(provider_type, [])
+
+
 def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
     """Get configuration for a provider using its adapter's config class."""
     config_class = instantiate_class_type(provider_spec.config_class)
@@ -155,6 +194,23 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
     return inference_providers, available_models


+# build a list of shields for all possible providers
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
+    for provider in providers:
+        provider_type = provider.provider_type.split("::")[1]
+        safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
+        if len(safety_model_entries) == 0:
+            continue
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
+
+
 def get_distribution_template() -> DistributionTemplate:
     remote_inference_providers, available_models = get_remote_inference_providers()

@@ -185,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",

@@ -244,7 +301,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)

     return DistributionTemplate(
         name=name,

@@ -266,9 +326,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=default_models + [embedding_model],
                 default_tool_groups=default_tool_groups,
                 # TODO: add a way to enable/disable shields on the fly
-                # default_shields=[
-                #     ShieldInput(provider_id="llama-guard", shield_id="${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}")
-                # ],
+                default_shields=shields,
             ),
         },
         run_config_env_vars={
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge

 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []

     # check for conflicts in model ids

@@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields


 class DefaultModel(BaseModel):
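As a sketch of how get_shield_registry turns the ollama safety entry into the shields block seen in the generated run.yaml above: with a single entry there is no shield-id conflict, and ids_conflict_in_models is assumed True here because model ids collide across providers in the starter template.

available_safety_models = {
    "${env.ENABLE_OLLAMA:=__disabled__}": [
        ProviderModelEntry(provider_model_id="${env.SAFETY_MODEL:=__disabled__}", model_type=ModelType.llm),
    ],
}
shields = get_shield_registry(available_safety_models, ids_conflict_in_models=True)
# -> [ShieldInput(shield_id="${env.SAFETY_MODEL:=__disabled__}",
#                 provider_shield_id="${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}")]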
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",

llama_stack/ui/app/logs/vector-stores/[id]/page.tsx (new file, 82 lines)
|
|
@ -0,0 +1,82 @@
|
|||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
|
||||
import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";
|
||||
|
||||
export default function VectorStoreDetailPage() {
|
||||
const params = useParams();
|
||||
const id = params.id as string;
|
||||
const client = useAuthClient();
|
||||
const router = useRouter();
|
||||
|
||||
const [store, setStore] = useState<VectorStore | null>(null);
|
||||
const [files, setFiles] = useState<VectorStoreFile[]>([]);
|
||||
const [isLoadingStore, setIsLoadingStore] = useState(true);
|
||||
const [isLoadingFiles, setIsLoadingFiles] = useState(true);
|
||||
const [errorStore, setErrorStore] = useState<Error | null>(null);
|
||||
const [errorFiles, setErrorFiles] = useState<Error | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!id) {
|
||||
setErrorStore(new Error("Vector Store ID is missing."));
|
||||
setIsLoadingStore(false);
|
||||
return;
|
||||
}
|
||||
const fetchStore = async () => {
|
||||
setIsLoadingStore(true);
|
||||
setErrorStore(null);
|
||||
try {
|
||||
const response = await client.vectorStores.retrieve(id);
|
||||
setStore(response as VectorStore);
|
||||
} catch (err) {
|
||||
setErrorStore(
|
||||
err instanceof Error
|
||||
? err
|
||||
: new Error("Failed to load vector store."),
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingStore(false);
|
||||
}
|
||||
};
|
||||
fetchStore();
|
||||
}, [id, client]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!id) {
|
||||
setErrorFiles(new Error("Vector Store ID is missing."));
|
||||
setIsLoadingFiles(false);
|
||||
return;
|
||||
}
|
||||
const fetchFiles = async () => {
|
||||
setIsLoadingFiles(true);
|
||||
setErrorFiles(null);
|
||||
try {
|
||||
const result = await client.vectorStores.files.list(id as any);
|
||||
setFiles((result as any).data);
|
||||
} catch (err) {
|
||||
setErrorFiles(
|
||||
err instanceof Error ? err : new Error("Failed to load files."),
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingFiles(false);
|
||||
}
|
||||
};
|
||||
fetchFiles();
|
||||
}, [id]);
|
||||
|
||||
return (
|
||||
<VectorStoreDetailView
|
||||
store={store}
|
||||
files={files}
|
||||
isLoadingStore={isLoadingStore}
|
||||
isLoadingFiles={isLoadingFiles}
|
||||
errorStore={errorStore}
|
||||
errorFiles={errorFiles}
|
||||
id={id}
|
||||
/>
|
||||
);
|
||||
}
|
||||
llama_stack/ui/app/logs/vector-stores/layout.tsx (new file, 16 lines)
|
|
@ -0,0 +1,16 @@
|
|||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import LogsLayout from "@/components/layout/logs-layout";
|
||||
|
||||
export default function VectorStoresLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
|
||||
{children}
|
||||
</LogsLayout>
|
||||
);
|
||||
}
|
||||
llama_stack/ui/app/logs/vector-stores/page.tsx (new file, 121 lines)
|
|
@ -0,0 +1,121 @@
|
|||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type {
|
||||
ListVectorStoresResponse,
|
||||
VectorStore,
|
||||
} from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { usePagination } from "@/hooks/use-pagination";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCaption,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
export default function VectorStoresPage() {
|
||||
const client = useAuthClient();
|
||||
const router = useRouter();
|
||||
const {
|
||||
data: stores,
|
||||
status,
|
||||
hasMore,
|
||||
error,
|
||||
loadMore,
|
||||
} = usePagination<VectorStore>({
|
||||
limit: 20,
|
||||
order: "desc",
|
||||
fetchFunction: async (client, params) => {
|
||||
const response = await client.vectorStores.list({
|
||||
after: params.after,
|
||||
limit: params.limit,
|
||||
order: params.order,
|
||||
} as any);
|
||||
return response as ListVectorStoresResponse;
|
||||
},
|
||||
errorMessagePrefix: "vector stores",
|
||||
});
|
||||
|
||||
// Auto-load all pages for infinite scroll behavior (like Responses)
|
||||
React.useEffect(() => {
|
||||
if (status === "idle" && hasMore) {
|
||||
loadMore();
|
||||
}
|
||||
}, [status, hasMore, loadMore]);
|
||||
|
||||
if (status === "loading") {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-8 w-full" />
|
||||
<Skeleton className="h-4 w-full" />
|
||||
<Skeleton className="h-4 w-full" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (status === "error") {
|
||||
return <div className="text-destructive">Error: {error?.message}</div>;
|
||||
}
|
||||
|
||||
if (!stores || stores.length === 0) {
|
||||
return <p>No vector stores found.</p>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="overflow-auto flex-1 min-h-0">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Completed</TableHead>
|
||||
<TableHead>Cancelled</TableHead>
|
||||
<TableHead>Failed</TableHead>
|
||||
<TableHead>In Progress</TableHead>
|
||||
<TableHead>Total</TableHead>
|
||||
<TableHead>Usage Bytes</TableHead>
|
||||
<TableHead>Provider ID</TableHead>
|
||||
<TableHead>Provider Vector DB ID</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{stores.map((store) => {
|
||||
const fileCounts = store.file_counts;
|
||||
const metadata = store.metadata || {};
|
||||
const providerId = metadata.provider_id ?? "";
|
||||
const providerDbId = metadata.provider_vector_db_id ?? "";
|
||||
|
||||
return (
|
||||
<TableRow
|
||||
key={store.id}
|
||||
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
|
||||
className="cursor-pointer hover:bg-muted/50"
|
||||
>
|
||||
<TableCell>{store.id}</TableCell>
|
||||
<TableCell>{store.name}</TableCell>
|
||||
<TableCell>
|
||||
{new Date(store.created_at * 1000).toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell>{fileCounts.completed}</TableCell>
|
||||
<TableCell>{fileCounts.cancelled}</TableCell>
|
||||
<TableCell>{fileCounts.failed}</TableCell>
|
||||
<TableCell>{fileCounts.in_progress}</TableCell>
|
||||
<TableCell>{fileCounts.total}</TableCell>
|
||||
<TableCell>{store.usage_bytes}</TableCell>
|
||||
<TableCell>{providerId}</TableCell>
|
||||
<TableCell>{providerDbId}</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,6 +1,11 @@
|
|||
"use client";
|
||||
|
||||
import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react";
|
||||
import {
|
||||
MessageSquareText,
|
||||
MessagesSquare,
|
||||
MoveUpRight,
|
||||
Database,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
|
@ -28,6 +33,11 @@ const logItems = [
|
|||
url: "/logs/responses",
|
||||
icon: MessagesSquare,
|
||||
},
|
||||
{
|
||||
title: "Vector Stores",
|
||||
url: "/logs/vector-stores",
|
||||
icon: Database,
|
||||
},
|
||||
{
|
||||
title: "Documentation",
|
||||
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
|
||||
|
|
@ -57,13 +67,13 @@ export function AppSidebar() {
|
|||
className={cn(
|
||||
"justify-start",
|
||||
isActive &&
|
||||
"bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary",
|
||||
"bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
|
||||
)}
|
||||
>
|
||||
<Link href={item.url}>
|
||||
<item.icon
|
||||
className={cn(
|
||||
isActive && "text-primary",
|
||||
isActive && "text-gray-900 dark:text-gray-100",
|
||||
"mr-2 h-4 w-4",
|
||||
)}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -93,7 +93,9 @@ export function PropertyItem({
|
|||
>
|
||||
<strong>{label}:</strong>{" "}
|
||||
{typeof value === "string" || typeof value === "number" ? (
|
||||
<span className="text-gray-900 font-medium">{value}</span>
|
||||
<span className="text-gray-900 dark:text-gray-100 font-medium">
|
||||
{value}
|
||||
</span>
|
||||
) : (
|
||||
value
|
||||
)}
|
||||
|
|
@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) {
|
|||
<CardTitle>Properties</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<ul className="space-y-2 text-sm text-gray-600">{children}</ul>
|
||||
<ul className="space-y-2 text-sm text-gray-600 dark:text-gray-400">
|
||||
{children}
|
||||
</ul>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -17,10 +17,10 @@ export const MessageBlock: React.FC<MessageBlockProps> = ({
|
|||
}) => {
|
||||
return (
|
||||
<div className={`mb-4 ${className}`}>
|
||||
<p className="py-1 font-semibold text-gray-800 mb-1">
|
||||
<p className="py-1 font-semibold text-muted-foreground mb-1">
|
||||
{label}
|
||||
{labelDetail && (
|
||||
<span className="text-xs text-gray-500 font-normal ml-1">
|
||||
<span className="text-xs text-muted-foreground font-normal ml-1">
|
||||
{labelDetail}
|
||||
</span>
|
||||
)}
|
||||
|
|
|
|||
llama_stack/ui/components/vector-stores/vector-store-detail.tsx (new file, 128 lines)
|
|
@@ -0,0 +1,128 @@
"use client";

import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
  DetailLoadingView,
  DetailErrorView,
  DetailNotFoundView,
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";
import {
  Table,
  TableBody,
  TableCaption,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";

interface VectorStoreDetailViewProps {
  store: VectorStore | null;
  files: VectorStoreFile[];
  isLoadingStore: boolean;
  isLoadingFiles: boolean;
  errorStore: Error | null;
  errorFiles: Error | null;
  id: string;
}

export function VectorStoreDetailView({
  store,
  files,
  isLoadingStore,
  isLoadingFiles,
  errorStore,
  errorFiles,
  id,
}: VectorStoreDetailViewProps) {
  const title = "Vector Store Details";

  if (errorStore) {
    return <DetailErrorView title={title} id={id} error={errorStore} />;
  }
  if (isLoadingStore) {
    return <DetailLoadingView title={title} />;
  }
  if (!store) {
    return <DetailNotFoundView title={title} id={id} />;
  }

  const mainContent = (
    <>
      <Card>
        <CardHeader>
          <CardTitle>Files</CardTitle>
        </CardHeader>
        <CardContent>
          {isLoadingFiles ? (
            <Skeleton className="h-4 w-full" />
          ) : errorFiles ? (
            <div className="text-destructive text-sm">
              Error loading files: {errorFiles.message}
            </div>
          ) : files.length > 0 ? (
            <Table>
              <TableCaption>Files in this vector store</TableCaption>
              <TableHeader>
                <TableRow>
                  <TableHead>ID</TableHead>
                  <TableHead>Status</TableHead>
                  <TableHead>Created</TableHead>
                  <TableHead>Usage Bytes</TableHead>
                </TableRow>
              </TableHeader>
              <TableBody>
                {files.map((file) => (
                  <TableRow key={file.id}>
                    <TableCell>{file.id}</TableCell>
                    <TableCell>{file.status}</TableCell>
                    <TableCell>
                      {new Date(file.created_at * 1000).toLocaleString()}
                    </TableCell>
                    <TableCell>{file.usage_bytes}</TableCell>
                  </TableRow>
                ))}
              </TableBody>
            </Table>
          ) : (
            <p className="text-gray-500 italic text-sm">
              No files in this vector store.
            </p>
          )}
        </CardContent>
      </Card>
    </>
  );

  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="ID" value={store.id} />
      <PropertyItem label="Name" value={store.name || ""} />
      <PropertyItem
        label="Created"
        value={new Date(store.created_at * 1000).toLocaleString()}
      />
      <PropertyItem label="Status" value={store.status} />
      <PropertyItem label="Total Files" value={store.file_counts.total} />
      <PropertyItem label="Usage Bytes" value={store.usage_bytes} />
      <PropertyItem
        label="Provider ID"
        value={(store.metadata.provider_id as string) || ""}
      />
      <PropertyItem
        label="Provider DB ID"
        value={(store.metadata.provider_vector_db_id as string) || ""}
      />
    </PropertiesCard>
  );

  return (
    <DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
  );
}
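A quick usage sketch for the new component (not part of the diff): a page is expected to fetch the vector store and its files elsewhere and pass them in through the props defined above. The data loading below is stubbed out; the page name and state wiring are placeholders, and only the prop shape comes from the component itself.

```tsx
// Hypothetical page wiring for VectorStoreDetailView. Data loading is
// stubbed; in the real UI it would be populated from the llama-stack client.
"use client";

import { useState } from "react";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";

export default function VectorStoreDetailPage({ id }: { id: string }) {
  // Placeholder state; a real page would fill these from an API call.
  const [store] = useState<VectorStore | null>(null);
  const [files] = useState<VectorStoreFile[]>([]);

  return (
    <VectorStoreDetailView
      store={store}
      files={files}
      isLoadingStore={store === null}
      isLoadingFiles={false}
      errorStore={null}
      errorFiles={null}
      id={id}
    />
  );
}
```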
llama_stack/ui/package-lock.json (generated, 474 lines changed)
@@ -15,7 +15,7 @@
        "@radix-ui/react-tooltip": "^1.2.6",
        "class-variance-authority": "^0.7.1",
        "clsx": "^2.1.1",
        "llama-stack-client": "0.2.13",
        "llama-stack-client": "^0.2.14",
        "lucide-react": "^0.510.0",
        "next": "15.3.3",
        "next-auth": "^4.24.11",
@@ -5999,46 +5599,6 @@
        "url": "https://github.com/sponsors/ljharb"
      }
    },
    "node_modules/esbuild": {
      "version": "0.25.5",
      "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz",
      "integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==",
      "hasInstallScript": true,
      "license": "MIT",
      "bin": {
        "esbuild": "bin/esbuild"
      },
      "engines": {
        "node": ">=18"
      },
      "optionalDependencies": {
        "@esbuild/aix-ppc64": "0.25.5",
        "@esbuild/android-arm": "0.25.5",
        "@esbuild/android-arm64": "0.25.5",
        "@esbuild/android-x64": "0.25.5",
        "@esbuild/darwin-arm64": "0.25.5",
        "@esbuild/darwin-x64": "0.25.5",
        "@esbuild/freebsd-arm64": "0.25.5",
        "@esbuild/freebsd-x64": "0.25.5",
        "@esbuild/linux-arm": "0.25.5",
        "@esbuild/linux-arm64": "0.25.5",
        "@esbuild/linux-ia32": "0.25.5",
        "@esbuild/linux-loong64": "0.25.5",
        "@esbuild/linux-mips64el": "0.25.5",
        "@esbuild/linux-ppc64": "0.25.5",
        "@esbuild/linux-riscv64": "0.25.5",
        "@esbuild/linux-s390x": "0.25.5",
        "@esbuild/linux-x64": "0.25.5",
        "@esbuild/netbsd-arm64": "0.25.5",
        "@esbuild/netbsd-x64": "0.25.5",
        "@esbuild/openbsd-arm64": "0.25.5",
        "@esbuild/openbsd-x64": "0.25.5",
        "@esbuild/sunos-x64": "0.25.5",
        "@esbuild/win32-arm64": "0.25.5",
        "@esbuild/win32-ia32": "0.25.5",
        "@esbuild/win32-x64": "0.25.5"
      }
    },
    "node_modules/escalade": {
      "version": "3.2.0",
      "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
@@ -6993,6 +6553,7 @@
      "version": "2.3.3",
      "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
      "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
      "dev": true,
      "hasInstallScript": true,
      "license": "MIT",
      "optional": true,
@@ -7154,6 +6715,7 @@
      "version": "4.10.0",
      "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz",
      "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==",
      "dev": true,
      "license": "MIT",
      "dependencies": {
        "resolve-pkg-maps": "^1.0.0"
@@ -9537,9 +9099,10 @@
      "license": "MIT"
    },
    "node_modules/llama-stack-client": {
      "version": "0.2.13",
      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.13.tgz",
      "integrity": "sha512-R1rTFLwgUimr+KjEUkzUvFL6vLASwS9qj3UDSVkJ5BmrKAs5GwVAMeL7yZaTBXGuPUVh124WSlC4d9H0FjWqLA==",
      "version": "0.2.14",
      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
      "integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
      "license": "Apache-2.0",
      "dependencies": {
        "@types/node": "^18.11.18",
        "@types/node-fetch": "^2.6.4",
@@ -9547,8 +9110,7 @@
        "agentkeepalive": "^4.2.1",
        "form-data-encoder": "1.7.2",
        "formdata-node": "^4.3.2",
        "node-fetch": "^2.6.7",
        "tsx": "^4.19.2"
        "node-fetch": "^2.6.7"
      }
    },
    "node_modules/llama-stack-client/node_modules/@types/node": {
@@ -11148,6 +10710,7 @@
      "version": "1.0.0",
      "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
      "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==",
      "dev": true,
      "license": "MIT",
      "funding": {
        "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1"
@@ -12198,25 +11761,6 @@
      "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
      "license": "0BSD"
    },
    "node_modules/tsx": {
      "version": "4.19.4",
      "resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.4.tgz",
      "integrity": "sha512-gK5GVzDkJK1SI1zwHf32Mqxf2tSJkNx+eYcNly5+nHvWqXUJYUkWBQtKauoESz3ymezAI++ZwT855x5p5eop+Q==",
      "license": "MIT",
      "dependencies": {
        "esbuild": "~0.25.0",
        "get-tsconfig": "^4.7.5"
      },
      "bin": {
        "tsx": "dist/cli.mjs"
      },
      "engines": {
        "node": ">=18.0.0"
      },
      "optionalDependencies": {
        "fsevents": "~2.3.3"
      }
    },
    "node_modules/tw-animate-css": {
      "version": "1.2.9",
      "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz",
@@ -20,7 +20,7 @@
    "@radix-ui/react-tooltip": "^1.2.6",
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
    "llama-stack-client": "0.2.13",
    "llama-stack-client": "^0.2.15",
    "lucide-react": "^0.510.0",
    "next": "15.3.3",
    "next-auth": "^4.24.11",