mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into use-openai-for-databricks
This commit is contained in:
commit
c8623607f5
31 changed files with 665 additions and 691 deletions
|
@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
|
|||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
|
||||
export MOCK_INFERENCE_MODEL=mock-inference
|
||||
|
||||
export MOCK_INFERENCE_URL=openai-mock-service:8080
|
||||
|
||||
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
export LLAMA_STACK_WORKERS=4
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
|
|
@ -5,6 +5,7 @@ data:
|
|||
image_name: kubernetes-benchmark-demo
|
||||
apis:
|
||||
- agents
|
||||
- files
|
||||
- inference
|
||||
- files
|
||||
- safety
|
||||
|
@ -23,6 +24,14 @@ data:
|
|||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
vector_io:
|
||||
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
|
|
|
@ -52,9 +52,20 @@ spec:
|
|||
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
|
||||
- name: VLLM_TLS_VERIFY
|
||||
value: "false"
|
||||
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
|
||||
- name: LLAMA_STACK_LOGGING
|
||||
value: "all=WARNING"
|
||||
- name: LLAMA_STACK_CONFIG
|
||||
value: "/etc/config/stack_run_config.yaml"
|
||||
- name: LLAMA_STACK_WORKERS
|
||||
value: "${LLAMA_STACK_WORKERS}"
|
||||
command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
|
||||
ports:
|
||||
- containerPort: 8323
|
||||
resources:
|
||||
requests:
|
||||
cpu: "${LLAMA_STACK_WORKERS}"
|
||||
limits:
|
||||
cpu: "${LLAMA_STACK_WORKERS}"
|
||||
volumeMounts:
|
||||
- name: llama-storage
|
||||
mountPath: /root/.llama
|
||||
|
|
|
@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
|
|||
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
|
||||
embedding_model = embed_lm.identifier
|
||||
vector_db_id = f"v{uuid.uuid4().hex}"
|
||||
client.vector_dbs.register(
|
||||
# The VectorDB API is deprecated; the server now returns its own authoritative ID.
|
||||
# We capture the correct ID from the response's .identifier attribute.
|
||||
vector_db_id = client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
).identifier
|
||||
|
||||
# Create Documents
|
||||
urls = [
|
||||
|
|
|
@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
|||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
If you want to use inline Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus[milvus-lite]
|
||||
```
|
||||
|
||||
If you want to use remote Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
|
|
|
@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
|
|
|
@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
|||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
|
|||
|
||||
|
||||
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
||||
adapter = AdapterSpec(**spec_data["adapter"])
|
||||
spec = remote_provider_spec(
|
||||
api=api,
|
||||
adapter=adapter,
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
)
|
||||
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"inline::{provider_name}",
|
||||
pip_packages=spec_data.get("pip_packages", []),
|
||||
module=spec_data["module"],
|
||||
config_class=spec_data["config_class"],
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
||||
provider_data_validator=spec_data.get("provider_data_validator"),
|
||||
container_image=spec_data.get("container_image"),
|
||||
)
|
||||
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
construct_stack,
|
||||
Stack,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
||||
stack = Stack(self.config, self.custom_provider_registry)
|
||||
await stack.initialize()
|
||||
self.impls = stack.impls
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
|
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
|
@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
|
|||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
|
@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
)
|
||||
|
||||
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
class StackApp(FastAPI):
|
||||
"""
|
||||
await shutdown_stack(app.__llama_stack_impls__)
|
||||
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
|
||||
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: StackRunConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stack: Stack = Stack(config)
|
||||
|
||||
# This code is called from a running event loop managed by uvicorn so we cannot simply call
|
||||
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
|
||||
# function.
|
||||
# As a workaround, we use a thread pool executor to run the initialize() method
|
||||
# in a separate thread.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, self.stack.initialize())
|
||||
future.result()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(app: StackApp):
|
||||
logger.info("Starting up")
|
||||
assert app.stack is not None
|
||||
app.stack.create_registry_refresh_task()
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
await app.stack.shutdown()
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -386,73 +398,61 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
# Load and process configuration
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
app = StackApp(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
impls = app.stack.impls
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
|
@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
|
@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
|
@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
self.provider_registry = provider_registry
|
||||
self.impls = None
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
|
@ -329,24 +333,28 @@ async def construct_stack(
|
|||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
assert self.impls is not None, "Must call initialize() before starting"
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
@ -360,11 +368,9 @@ async def construct_stack(
|
|||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
async def shutdown(self):
|
||||
for impl in self.impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
|
|
|
@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
|
|||
remote_providers = [
|
||||
provider
|
||||
for provider in available_providers()
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
]
|
||||
|
||||
inference_providers = []
|
||||
for provider_spec in remote_providers:
|
||||
provider_type = provider_spec.adapter.adapter_type
|
||||
provider_type = provider_spec.adapter_type
|
||||
|
||||
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||
|
|
|
@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
|
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
|
|||
async def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
default_factory=str,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: str = Field(
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
|
|||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||
""",
|
||||
)
|
||||
# module field is inherited from ProviderSpec
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: AdapterSpec = Field(
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here.
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -234,33 +207,6 @@ API responses, specify the adapter here.
|
|||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
# module field is inherited from ProviderSpec
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> str | None:
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def remote_provider_spec(
|
||||
api: Api,
|
||||
adapter: AdapterSpec,
|
||||
api_dependencies: list[Api] | None = None,
|
||||
optional_api_dependencies: list[Api] | None = None,
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
module=adapter.module,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
optional_api_dependencies=optional_api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(StrEnum):
|
||||
OK = "OK"
|
||||
|
|
|
@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
|
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,10 +24,10 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[],
|
||||
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
|
@ -36,17 +35,15 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.eval,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
),
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
|
@ -4,13 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
||||
|
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
provider_type="remote::s3",
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
|
@ -49,10 +48,10 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
description="Sentence Transformers inference provider for text embeddings and similarity search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
|
@ -60,62 +59,56 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
provider_type="remote::ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
provider_type="remote::vllm",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tgi",
|
||||
provider_type="remote::tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::serverless",
|
||||
provider_type="remote::hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
provider_type="remote::hf::endpoint",
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
provider_type="remote::fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
|
@ -124,11 +117,10 @@ def available_providers() -> list[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together",
|
||||
provider_type="remote::together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
|
@ -137,85 +129,82 @@ def available_providers() -> list[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
provider_type="remote::databricks",
|
||||
pip_packages=["databricks-sdk"],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="runpod",
|
||||
provider_type="remote::runpod",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="gemini",
|
||||
pip_packages=["litellm"],
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vertexai",
|
||||
pip_packages=["litellm", "google-cloud-aiplatform"],
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
|
@ -240,65 +229,63 @@ Available Models:
|
|||
- vertex_ai/gemini-2.5-flash
|
||||
- vertex_ai/gemini-2.5-pro""",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["litellm"],
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm"],
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="passthrough",
|
||||
provider_type="remote::passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
|
@ -310,5 +297,4 @@ Provider documentation
|
|||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||
""",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
|
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bing-search",
|
||||
provider_type="remote::bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="wolfram-alpha",
|
||||
provider_type="remote::wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -300,13 +299,15 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
Please refer to the sqlite-vec provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
|
@ -341,9 +342,6 @@ pip install chromadb
|
|||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::chromadb",
|
||||
|
@ -387,13 +385,15 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
|
@ -496,17 +496,16 @@ docker pull pgvector/pgvector:pg17
|
|||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
|
@ -539,9 +538,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
|
|||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::qdrant",
|
||||
|
@ -594,27 +590,28 @@ docker pull qdrant/qdrant
|
|||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
|
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
|||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
If you want to use inline Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus[milvus-lite]
|
||||
```
|
||||
|
||||
If you want to use remote Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
|
@ -807,13 +810,10 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
|
|||
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
@ -137,7 +137,7 @@ class S3FilesImpl(Files):
|
|||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
|
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
|
|||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -54,7 +54,7 @@ class InferenceStore:
|
|||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
@ -202,7 +202,6 @@ class InferenceStore:
|
|||
order_by=[("created", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [
|
||||
|
@ -229,7 +228,6 @@ class InferenceStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
table="chat_completions",
|
||||
where={"id": completion_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
|
|
@ -28,8 +28,7 @@ class ResponsesStore:
|
|||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
||||
self.policy = policy
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
|
@ -87,7 +86,6 @@ class ResponsesStore:
|
|||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||
|
@ -105,7 +103,6 @@ class ResponsesStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
"openai_responses",
|
||||
where={"id": response_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
@ -116,7 +113,7 @@ class ResponsesStore:
|
|||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found")
|
||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
|||
access control policies, user attribute capture, and SQL filtering optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_store: SqlStore):
|
||||
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||
"""
|
||||
Initialize the authorization layer.
|
||||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
:param policy: Access control policy to use for authorization
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.policy = policy
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
|||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Fetch all rows with automatic access control filtering."""
|
||||
access_where = self._build_access_control_where_clause(policy)
|
||||
access_where = self._build_access_control_where_clause(self.policy)
|
||||
rows = await self.sql_store.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
|||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
||||
return PaginatedResponse(
|
||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
|||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch one row with automatic access control checking."""
|
||||
results = await self.fetch_all(
|
||||
table=table,
|
||||
policy=policy,
|
||||
where=where,
|
||||
limit=1,
|
||||
order_by=order_by,
|
||||
|
|
1
tests/external/kaze.yaml
vendored
1
tests/external/kaze.yaml
vendored
|
@ -1,4 +1,3 @@
|
|||
adapter:
|
||||
adapter_type: kaze
|
||||
pip_packages: ["tests/external/llama-stack-provider-kaze"]
|
||||
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Protocol
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
|
@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api=Api.weather,
|
||||
provider_type="remote::kaze",
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="kaze",
|
||||
module="llama_stack_provider_kaze",
|
||||
pip_packages=["llama_stack_provider_kaze"],
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def authorized_store(backend_config):
|
|||
config = config_func()
|
||||
|
||||
base_sqlstore = sqlstore_impl(config)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
yield authorized_store
|
||||
|
||||
|
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
|||
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||
|
||||
# Test fetching with no user - should not error on JSON comparison
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
result = await authorized_store.fetch_all(table_name)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["id"] == "1"
|
||||
assert result.data[0]["access_attributes"] is None
|
||||
|
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
|||
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||
|
||||
# Fetch all - admin should see both
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
result = await authorized_store.fetch_all(table_name)
|
||||
assert len(result.data) == 2
|
||||
|
||||
# Test with non-admin user
|
||||
|
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
|||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
# Should only see public record
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
result = await authorized_store.fetch_all(table_name)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["id"] == "1"
|
||||
|
||||
|
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
|||
|
||||
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||
mock_get_authenticated_user.return_value = multi_user
|
||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||
result = await authorized_store.fetch_all(table_name)
|
||||
|
||||
# Should see:
|
||||
# - public record (1) - no access_attributes
|
||||
|
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
|
|||
),
|
||||
]
|
||||
|
||||
# Create a new authorized store with the owner-only policy
|
||||
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
|
||||
|
||||
# Test user1 access - should only see their own record
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
result = await owner_only_store.fetch_all(table_name)
|
||||
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
||||
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
||||
|
||||
# Test user2 access - should only see their own record
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
result = await owner_only_store.fetch_all(table_name)
|
||||
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
||||
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
||||
|
||||
# Test with anonymous user - should see no records
|
||||
mock_get_authenticated_user.return_value = None
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
result = await owner_only_store.fetch_all(table_name)
|
||||
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
||||
|
||||
finally:
|
||||
|
|
|
@ -66,7 +66,6 @@ def base_config(tmp_path):
|
|||
def provider_spec_yaml():
|
||||
"""Common provider spec YAML for testing."""
|
||||
return """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
module: test_provider
|
||||
|
@ -182,9 +181,9 @@ class TestProviderRegistry:
|
|||
assert Api.inference in registry
|
||||
assert "remote::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["remote::test_provider"]
|
||||
assert provider.adapter.adapter_type == "test_provider"
|
||||
assert provider.adapter.module == "test_provider"
|
||||
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert provider.adapter_type == "test_provider"
|
||||
assert provider.module == "test_provider"
|
||||
assert provider.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert Api.safety in provider.api_dependencies
|
||||
|
||||
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
|
||||
|
@ -246,7 +245,6 @@ class TestProviderRegistry:
|
|||
"""Test handling of malformed remote provider spec (missing required fields)."""
|
||||
remote_dir, _ = api_directories
|
||||
malformed_spec = """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
# Missing required fields
|
||||
api_dependencies:
|
||||
|
@ -270,7 +268,7 @@ pip_packages:
|
|||
with open(inline_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
get_provider_registry(base_config)
|
||||
assert "config_class" in str(exc_info.value)
|
||||
|
||||
|
|
|
@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
|
|
@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
# Create table with access control
|
||||
await sqlstore.create_table(
|
||||
|
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
|||
mock_get_authenticated_user.return_value = admin_user
|
||||
|
||||
# Admin should see both documents
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "Admin Document"
|
||||
|
||||
# User should only see their document
|
||||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||
assert len(result.data) == 0
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 2})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "User Document"
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
||||
row = await sqlstore.fetch_one("documents", where={"id": 1})
|
||||
assert row is None
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
||||
row = await sqlstore.fetch_one("documents", where={"id": 2})
|
||||
assert row is not None
|
||||
assert row["title"] == "User Document"
|
||||
|
||||
|
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
await sqlstore.create_table(
|
||||
table="resources",
|
||||
|
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||
mock_get_authenticated_user.return_value = user
|
||||
|
||||
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
||||
sql_results = await sqlstore.fetch_all("resources")
|
||||
sql_ids = {row["id"] for row in sql_results.data}
|
||||
policy_ids = set()
|
||||
for scenario in test_scenarios:
|
||||
|
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
await authorized_store.create_table(
|
||||
table="user_data",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue