Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

Commit 9ceb45f611: Merge branch 'main' into use-openai-for-cerebras

53 changed files with 2612 additions and 1966 deletions
.github/workflows/python-build-test.yml (vendored): 2 changed lines
@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install uv
-        uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
+        uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true
@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B

-export MOCK_INFERENCE_MODEL=mock-inference
-
-export MOCK_INFERENCE_URL=openai-mock-service:8080
-
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4

 set -euo pipefail
 set -x
@@ -5,6 +5,7 @@ data:
     image_name: kubernetes-benchmark-demo
     apis:
     - agents
+    - files
    - inference
    - files
    - safety
@@ -23,6 +24,14 @@ data:
       - provider_id: sentence-transformers
         provider_type: inline::sentence-transformers
         config: {}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
       vector_io:
       - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
         provider_type: remote::chromadb
@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
           - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
           - name: llama-storage
             mountPath: /root/.llama
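The new container command starts the server through uvicorn's application factory instead of the module's `main()`. A hedged sketch of the equivalent programmatic invocation follows; it assumes `LLAMA_STACK_CONFIG` points at the mounted run config, and `uvicorn.run` with an import string plus `factory=True` is standard uvicorn usage, not something added by this change:

```python
import os

import uvicorn

# Mirror of the container command above: uvicorn calls create_app() itself,
# so the config has to arrive via the environment rather than argv.
os.environ.setdefault("LLAMA_STACK_CONFIG", "/etc/config/stack_run_config.yaml")

uvicorn.run(
    "llama_stack.core.server.server:create_app",
    host="0.0.0.0",
    port=8323,
    workers=int(os.environ.get("LLAMA_STACK_WORKERS", "4")),  # cf. the CPU request/limit above
    factory=True,
)
```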
@@ -11,6 +11,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following providers:
 | agents | `inline::meta-reference` |
 | datasetio | `inline::localfs`, `remote::nvidia` |
 | eval | `remote::nvidia` |
+| files | `inline::localfs` |
 | inference | `remote::nvidia` |
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
@@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
 embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
 embedding_model = embed_lm.identifier
 vector_db_id = f"v{uuid.uuid4().hex}"
-client.vector_dbs.register(
+# The VectorDB API is deprecated; the server now returns its own authoritative ID.
+# We capture the correct ID from the response's .identifier attribute.
+vector_db_id = client.vector_dbs.register(
     vector_db_id=vector_db_id,
     embedding_model=embedding_model,
-)
+).identifier

 # Create Documents
 urls = [
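A hedged restatement of the pattern the updated snippet encourages; only the `.identifier` handling comes from this change, while the client setup simply mirrors the surrounding tutorial:

```python
import uuid

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
embedding_model = next(m for m in client.models.list() if m.model_type == "embedding").identifier

# Ask for a name, but treat the server's response as the source of truth.
registered = client.vector_dbs.register(
    vector_db_id=f"v{uuid.uuid4().hex}",
    embedding_model=embedding_model,
)
vector_db_id = registered.identifier  # may differ from the requested name; use this downstream
```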
@@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:

 ## Installation

-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:

 ```bash
 pip install pymilvus
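As a quick sanity check of the two install paths, here is a minimal, hedged sketch using pymilvus's `MilvusClient`; the local file path and the remote URI are placeholders, not values from this change:

```python
from pymilvus import MilvusClient

# Inline (milvus-lite): data lives in a local file, no server needed.
local_client = MilvusClient("./milvus_demo.db")

# Remote Milvus: point the client at a running Milvus server instead.
remote_client = MilvusClient(uri="http://localhost:19530")

local_client.create_collection(collection_name="demo", dimension=384)
print(local_client.list_collections())
```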
@@ -147,7 +147,7 @@ WORKDIR /app

 RUN dnf -y update && dnf install -y iputils git net-tools wget \
     vim-minimal python3.12 python3.12-pip python3.12-wheel \
-    python3.12-setuptools python3.12-devel gcc make && \
+    python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
     ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all

 ENV UV_SYSTEM_PYTHON=1
@@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
     procps psmisc lsof \
     traceroute \
     bubblewrap \
-    gcc \
+    gcc g++ \
     && rm -rf /var/lib/apt/lists/*

 ENV UV_SYSTEM_PYTHON=1
@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )

-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-

 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):
@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:


 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec


 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-        provider_data_validator=spec_data.get("provider_data_validator"),
-        container_image=spec_data.get("container_image"),
-    )
+    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
     return spec
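For orientation, a hedged sketch of what the flattened external-provider data now looks like to this loader; the module and config values below are illustrative placeholders, and only the key names plus the splat into `RemoteProviderSpec` follow the code above:

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Hypothetical spec data for an external remote provider (placeholder values).
spec_data = {
    "adapter_type": "my_backend",
    "module": "my_org.llama_stack_provider",
    "config_class": "my_org.llama_stack_provider.config.MyBackendConfig",
    "pip_packages": ["my-backend-sdk"],
}

# The loader no longer wraps these fields in an AdapterSpec; it forwards them directly:
spec = RemoteProviderSpec(api=Api.inference, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
```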
@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):

         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e

+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
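In short, library-mode initialization now goes through the `Stack` object. A hedged sketch of the flow, simplified and with error handling omitted:

```python
# Simplified sketch of the new initialization path (not the full library client).
from llama_stack.core.stack import Stack


async def initialize_impls(config, custom_provider_registry=None):
    stack = Stack(config, custom_provider_registry)
    await stack.initialize()  # resolves providers and registers resources
    return stack.impls        # dict[Api, Any], the same object the client stores
```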
@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             try:
                 models = await provider.list_models()
             except Exception as e:
-                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
                 continue

             self.listed_providers.add(provider_id)
@@ -6,6 +6,7 @@

 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
 )


-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()


 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()


 def is_streaming_request(func_name: str, request: Request, **kwargs):
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)


-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)

+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
     logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e

     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))

     _log_run_config(run_config=config)

-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )

     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)

-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls

     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
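A hedged sketch of exercising the new factory directly, for example from a test; the config path and env var are placeholders, and only the `create_app(config_file=..., env_vars=...)` signature comes from the code above:

```python
from llama_stack.core.server.server import create_app

# Build the app without going through main()/argparse.
app = create_app(
    config_file="run.yaml",                      # placeholder path
    env_vars=["INFERENCE_MODEL=my-test-model"],  # same KEY=value format as --env
)

# The resolved provider implementations are reachable via the wrapped Stack.
print(sorted(api.value for api in app.stack.impls))
```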
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)

-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+    config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn

     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)

-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()


 def _log_run_config(run_config: StackRunConfig):
@@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig):
     impls[Api.prompts] = prompts_impl


-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
-    if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
-        from llama_stack.testing.inference_recorder import setup_inference_recording
-
-        global TEST_RECORDING_CONTEXT
-        TEST_RECORDING_CONTEXT = setup_inference_recording()
-        if TEST_RECORDING_CONTEXT:
-            TEST_RECORDING_CONTEXT.__enter__()
-            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
-
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    policy = run_config.server.auth.access_policy if run_config.server.auth else []
-    impls = await resolve_impls(
-        run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
-    )
-
-    # Add internal implementations after all other providers are resolved
-    add_internal_implementations(impls, run_config)
-
-    if Api.prompts in impls:
-        await impls[Api.prompts].initialize()
-
-    await register_resources(run_config, impls)
-
-    await refresh_registry_once(impls)
-
-    global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
-
-    def cb(task):
-        import traceback
-
-        if task.cancelled():
-            logger.error("Model refresh task cancelled")
-        elif task.exception():
-            logger.error(f"Model refresh task failed: {task.exception()}")
-            traceback.print_exception(task.exception())
-        else:
-            logger.debug("Model refresh task completed")
-
-    REGISTRY_REFRESH_TASK.add_done_callback(cb)
-    return impls
-
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
-        impl_name = impl.__class__.__name__
-        logger.info(f"Shutting down {impl_name}")
-        try:
-            if hasattr(impl, "shutdown"):
-                await asyncio.wait_for(impl.shutdown(), timeout=5)
-            else:
-                logger.warning(f"No shutdown method for {impl_name}")
-        except TimeoutError:
-            logger.exception(f"Shutdown timeout for {impl_name}")
-        except (Exception, asyncio.CancelledError) as e:
-            logger.exception(f"Failed to shutdown {impl_name}: {e}")
-
-    global TEST_RECORDING_CONTEXT
-    if TEST_RECORDING_CONTEXT:
-        try:
-            TEST_RECORDING_CONTEXT.__exit__(None, None, None)
-        except Exception as e:
-            logger.error(f"Error during inference recording cleanup: {e}")
-
-    global REGISTRY_REFRESH_TASK
-    if REGISTRY_REFRESH_TASK:
-        REGISTRY_REFRESH_TASK.cancel()
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
+        if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
+            from llama_stack.testing.inference_recorder import setup_inference_recording
+
+            global TEST_RECORDING_CONTEXT
+            TEST_RECORDING_CONTEXT = setup_inference_recording()
+            if TEST_RECORDING_CONTEXT:
+                TEST_RECORDING_CONTEXT.__enter__()
+                logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
+        impls = await resolve_impls(
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
+        )
+
+        # Add internal implementations after all other providers are resolved
+        add_internal_implementations(impls, self.run_config)
+
+        if Api.prompts in impls:
+            await impls[Api.prompts].initialize()
+
+        await register_resources(self.run_config, impls)
+
+        await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
+
+        global REGISTRY_REFRESH_TASK
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
+
+        def cb(task):
+            import traceback
+
+            if task.cancelled():
+                logger.error("Model refresh task cancelled")
+            elif task.exception():
+                logger.error(f"Model refresh task failed: {task.exception()}")
+                traceback.print_exception(task.exception())
+            else:
+                logger.debug("Model refresh task completed")
+
+        REGISTRY_REFRESH_TASK.add_done_callback(cb)
+
+    async def shutdown(self):
+        for impl in self.impls.values():
+            impl_name = impl.__class__.__name__
+            logger.info(f"Shutting down {impl_name}")
+            try:
+                if hasattr(impl, "shutdown"):
+                    await asyncio.wait_for(impl.shutdown(), timeout=5)
+                else:
+                    logger.warning(f"No shutdown method for {impl_name}")
+            except TimeoutError:
+                logger.exception(f"Shutdown timeout for {impl_name}")
+            except (Exception, asyncio.CancelledError) as e:
+                logger.exception(f"Failed to shutdown {impl_name}: {e}")
+
+        global TEST_RECORDING_CONTEXT
+        if TEST_RECORDING_CONTEXT:
+            try:
+                TEST_RECORDING_CONTEXT.__exit__(None, None, None)
+            except Exception as e:
+                logger.error(f"Error during inference recording cleanup: {e}")
+
+        global REGISTRY_REFRESH_TASK
+        if REGISTRY_REFRESH_TASK:
+            REGISTRY_REFRESH_TASK.cancel()


 async def refresh_registry_once(impls: dict[Api, Any]):
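For embedders, a hedged sketch of the full lifecycle the new class implies; the `Stack` methods come from the code above, while the surrounding coroutine and sleep are illustrative only:

```python
import asyncio

from llama_stack.core.stack import Stack


async def serve_for_a_while(run_config):
    stack = Stack(run_config)
    await stack.initialize()
    stack.create_registry_refresh_task()  # periodic model-registry refresh, formerly started inside construct_stack
    try:
        await asyncio.sleep(60)           # ... application work ...
    finally:
        await stack.shutdown()            # also cancels the refresh task and shuts down providers
```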
@@ -23,6 +23,8 @@ distribution_spec:
     - provider_type: inline::basic
     tool_runtime:
     - provider_type: inline::rag-runtime
+    files:
+    - provider_type: inline::localfs
   image_type: venv
   additional_pip_packages:
   - aiosqlite
@@ -8,6 +8,7 @@ from pathlib import Path

 from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
 from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
@@ -15,7 +16,7 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig


-def get_distribution_template() -> DistributionTemplate:
+def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
     providers = {
         "inference": [BuildProvider(provider_type="remote::nvidia")],
         "vector_io": [BuildProvider(provider_type="inline::faiss")],
@@ -30,6 +31,7 @@ def get_distribution_template() -> DistributionTemplate:
         ],
         "scoring": [BuildProvider(provider_type="inline::basic")],
         "tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
+        "files": [BuildProvider(provider_type="inline::localfs")],
     }

     inference_provider = Provider(
@@ -52,6 +54,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIAEvalConfig.sample_run_config(),
     )
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="nvidia",
@@ -73,7 +80,7 @@ def get_distribution_template() -> DistributionTemplate:

     default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
-        name="nvidia",
+        name=name,
         distro_type="self_hosted",
         description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
         container_image=None,
@@ -86,6 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
                 "inference": [inference_provider],
                 "datasetio": [datasetio_provider],
                 "eval": [eval_provider],
+                "files": [files_provider],
             },
             default_models=default_models,
             default_tool_groups=default_tool_groups,
@@ -97,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
                     safety_provider,
                 ],
                 "eval": [eval_provider],
+                "files": [files_provider],
             },
             default_models=[inference_model, safety_model],
             default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - post_training
 - safety
@@ -88,6 +89,14 @@ providers:
   tool_runtime:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - post_training
 - safety
@@ -77,6 +78,14 @@ providers:
   tool_runtime:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
     remote_providers = [
         provider
         for provider in available_providers()
-        if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
+        if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
     ]

     inference_providers = []
     for provider_spec in remote_providers:
-        provider_type = provider_spec.adapter.adapter_type
+        provider_type = provider_spec.adapter_type

         if provider_type in INFERENCE_PROVIDER_IDS:
             provider_id = INFERENCE_PROVIDER_IDS[provider_type]
@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
     """,
     )

+    pip_packages: list[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+    provider_data_validator: str | None = Field(
+        default=None,
+    )
+
     is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")

     # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
     async def get_provider_impl(self, routing_key: str) -> Any: ...


-# TODO: this can now be inlined into RemoteProviderSpec
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_type: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        default_factory=str,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: str = Field(
-        description="Fully-qualified classname of the config for this provider",
-    )
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
-    description: str | None = Field(
-        default=None,
-        description="""
-A description of the provider. This is used to display in the documentation.
-""",
-    )
-
-
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
     container_image: str | None = Field(
         default=None,
         description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_packages will be ignored.
 If a provider depends on other providers, the dependencies MUST NOT specify a container image.
 """,
     )
-    # module field is inherited from ProviderSpec
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
     description: str | None = Field(
         default=None,
         description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter_type: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+
+    description: str | None = Field(
+        default=None,
         description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+A description of the provider. This is used to display in the documentation.
 """,
     )
@@ -234,33 +207,6 @@ API responses, specify the adapter here.
     def container_image(self) -> str | None:
         return None

-    # module field is inherited from ProviderSpec
-
-    @property
-    def pip_packages(self) -> list[str]:
-        return self.adapter.pip_packages
-
-    @property
-    def provider_data_validator(self) -> str | None:
-        return self.adapter.provider_data_validator
-
-
-def remote_provider_spec(
-    api: Api,
-    adapter: AdapterSpec,
-    api_dependencies: list[Api] | None = None,
-    optional_api_dependencies: list[Api] | None = None,
-) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        module=adapter.module,
-        adapter=adapter,
-        api_dependencies=api_dependencies or [],
-        optional_api_dependencies=optional_api_dependencies or [],
-    )
-
-
 class HealthStatus(StrEnum):
     OK = "OK"
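To make the shape change concrete, a before/after sketch of one registry entry; the ollama field values are taken from the registry diffs further below, and the commented-out helper form is the pre-change API:

```python
# After this change: flat construction, no AdapterSpec wrapper.
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

ollama_spec = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="ollama",
    provider_type="remote::ollama",
    pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
    config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
    module="llama_stack.providers.remote.inference.ollama",
    description="Ollama inference provider for running local models through the Ollama runtime.",
)

# Before this change, the same entry went through the now-removed helper:
#   remote_provider_spec(api=Api.inference, adapter=AdapterSpec(adapter_type="ollama", ...))
```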
@@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
         storage_path.mkdir(parents=True, exist_ok=True)

         # Initialize SQL store for metadata
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
         await self.sql_store.create_table(
             "openai_files",
             {
@@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
         if not self.sql_store:
             raise RuntimeError("Files provider not initialized")

-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):

         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
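The pattern in brief, as a hedged sketch; the table name and where-clause are illustrative, the import paths are assumed from the repository layout, and only the constructor and fetch signatures mirror the diff above:

```python
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl


async def lookup_file(config, policy, file_id: str):
    # The access policy is now bound once, at construction time...
    sql_store = AuthorizedSqlStore(sqlstore_impl(config.metadata_store), policy)
    # ...so individual queries no longer repeat it.
    return await sql_store.fetch_one("openai_files", where={"id": file_id})
```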
@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
@@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
             api_dependencies=[],
             description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
-                adapter_type="huggingface",
-                pip_packages=[
-                    "datasets>=4.0.0",
-                ],
-                module="llama_stack.providers.remote.datasetio.huggingface",
-                config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
-                description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
-            ),
+            adapter_type="huggingface",
+            provider_type="remote::huggingface",
+            pip_packages=[
+                "datasets>=4.0.0",
+            ],
+            module="llama_stack.providers.remote.datasetio.huggingface",
+            config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
+            description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.datasetio,
-            adapter=AdapterSpec(
-                adapter_type="nvidia",
-                pip_packages=[
-                    "datasets>=4.0.0",
-                ],
-                module="llama_stack.providers.remote.datasetio.nvidia",
-                config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
-                description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
-            ),
+            adapter_type="nvidia",
+            provider_type="remote::nvidia",
+            module="llama_stack.providers.remote.datasetio.nvidia",
+            config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
+            pip_packages=[
+                "datasets>=4.0.0",
+            ],
+            description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
         ),
     ]
@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec


 def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
             ],
             description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.eval,
-            adapter=AdapterSpec(
-                adapter_type="nvidia",
-                pip_packages=[
-                    "requests",
-                ],
-                module="llama_stack.providers.remote.eval.nvidia",
-                config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
-                description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
-            ),
+            adapter_type="nvidia",
+            pip_packages=[
+                "requests",
+            ],
+            provider_type="remote::nvidia",
+            module="llama_stack.providers.remote.eval.nvidia",
+            config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
+            description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
@@ -4,13 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.datatypes import (
-    AdapterSpec,
-    Api,
-    InlineProviderSpec,
-    ProviderSpec,
-    remote_provider_spec,
-)
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
             description="Local filesystem-based file storage provider for managing files and documents locally.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.files,
-            adapter=AdapterSpec(
-                adapter_type="s3",
-                pip_packages=["boto3"] + sql_store_pip_packages,
-                module="llama_stack.providers.remote.files.s3",
-                config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
-                description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
-            ),
+            provider_type="remote::s3",
+            adapter_type="s3",
+            pip_packages=["boto3"] + sql_store_pip_packages,
+            module="llama_stack.providers.remote.files.s3",
+            config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
+            description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
         ),
     ]
@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

 META_REFERENCE_DEPS = [
@@ -49,177 +48,167 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
             description="Sentence Transformers inference provider for text embeddings and similarity search.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="cerebras",
-                pip_packages=[
-                    "cerebras_cloud_sdk",
-                ],
-                module="llama_stack.providers.remote.inference.cerebras",
-                config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
-                description="Cerebras inference provider for running models on Cerebras Cloud platform.",
-            ),
+            adapter_type="cerebras",
+            provider_type="remote::cerebras",
+            pip_packages=[
+                "cerebras_cloud_sdk",
+            ],
+            module="llama_stack.providers.remote.inference.cerebras",
+            config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+            description="Cerebras inference provider for running models on Cerebras Cloud platform.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="ollama",
-                pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
-                config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
-                module="llama_stack.providers.remote.inference.ollama",
-                description="Ollama inference provider for running local models through the Ollama runtime.",
-            ),
+            adapter_type="ollama",
+            provider_type="remote::ollama",
+            pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
+            config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
+            module="llama_stack.providers.remote.inference.ollama",
+            description="Ollama inference provider for running local models through the Ollama runtime.",
        ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="vllm",
-                pip_packages=[],
-                module="llama_stack.providers.remote.inference.vllm",
-                config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
-                provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
-                description="Remote vLLM inference provider for connecting to vLLM servers.",
-            ),
+            adapter_type="vllm",
+            provider_type="remote::vllm",
+            pip_packages=[],
+            module="llama_stack.providers.remote.inference.vllm",
+            config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
+            description="Remote vLLM inference provider for connecting to vLLM servers.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="tgi",
-                pip_packages=["huggingface_hub", "aiohttp"],
-                module="llama_stack.providers.remote.inference.tgi",
-                config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
-                description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
-            ),
+            adapter_type="tgi",
+            provider_type="remote::tgi",
+            pip_packages=["huggingface_hub", "aiohttp"],
+            module="llama_stack.providers.remote.inference.tgi",
+            config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
+            description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="hf::serverless",
-                pip_packages=["huggingface_hub", "aiohttp"],
-                module="llama_stack.providers.remote.inference.tgi",
-                config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
-                description="HuggingFace Inference API serverless provider for on-demand model inference.",
-            ),
+            adapter_type="hf::serverless",
+            provider_type="remote::hf::serverless",
+            pip_packages=["huggingface_hub", "aiohttp"],
+            module="llama_stack.providers.remote.inference.tgi",
+            config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
+            description="HuggingFace Inference API serverless provider for on-demand model inference.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="hf::endpoint",
-                pip_packages=["huggingface_hub", "aiohttp"],
-                module="llama_stack.providers.remote.inference.tgi",
-                config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
-                description="HuggingFace Inference Endpoints provider for dedicated model serving.",
-            ),
+            provider_type="remote::hf::endpoint",
+            adapter_type="hf::endpoint",
+            pip_packages=["huggingface_hub", "aiohttp"],
+            module="llama_stack.providers.remote.inference.tgi",
+            config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
+            description="HuggingFace Inference Endpoints provider for dedicated model serving.",
         ),
-        remote_provider_spec(
+        RemoteProviderSpec(
             api=Api.inference,
-            adapter=AdapterSpec(
-                adapter_type="fireworks",
+            adapter_type="fireworks",
+            provider_type="remote::fireworks",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"fireworks-ai<=0.17.16",
|
"fireworks-ai<=0.17.16",
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.remote.inference.fireworks",
|
module="llama_stack.providers.remote.inference.fireworks",
|
||||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="together",
|
||||||
adapter_type="together",
|
provider_type="remote::together",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"together",
|
"together",
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.remote.inference.together",
|
module="llama_stack.providers.remote.inference.together",
|
||||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="bedrock",
|
||||||
adapter_type="bedrock",
|
provider_type="remote::bedrock",
|
||||||
pip_packages=["boto3"],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.remote.inference.bedrock",
|
module="llama_stack.providers.remote.inference.bedrock",
|
||||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="databricks",
|
||||||
adapter_type="databricks",
|
provider_type="remote::databricks",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.remote.inference.databricks",
|
module="llama_stack.providers.remote.inference.databricks",
|
||||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="nvidia",
|
||||||
adapter_type="nvidia",
|
provider_type="remote::nvidia",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.remote.inference.nvidia",
|
module="llama_stack.providers.remote.inference.nvidia",
|
||||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="runpod",
|
||||||
adapter_type="runpod",
|
provider_type="remote::runpod",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.remote.inference.runpod",
|
module="llama_stack.providers.remote.inference.runpod",
|
||||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="openai",
|
||||||
adapter_type="openai",
|
provider_type="remote::openai",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm"],
|
||||||
module="llama_stack.providers.remote.inference.openai",
|
module="llama_stack.providers.remote.inference.openai",
|
||||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="anthropic",
|
||||||
adapter_type="anthropic",
|
provider_type="remote::anthropic",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm"],
|
||||||
module="llama_stack.providers.remote.inference.anthropic",
|
module="llama_stack.providers.remote.inference.anthropic",
|
||||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="gemini",
|
||||||
adapter_type="gemini",
|
provider_type="remote::gemini",
|
||||||
pip_packages=["litellm"],
|
pip_packages=[
|
||||||
module="llama_stack.providers.remote.inference.gemini",
|
"litellm",
|
||||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
],
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
module="llama_stack.providers.remote.inference.gemini",
|
||||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||||
),
|
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||||
|
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="vertexai",
|
||||||
adapter_type="vertexai",
|
provider_type="remote::vertexai",
|
||||||
pip_packages=["litellm", "google-cloud-aiplatform"],
|
pip_packages=[
|
||||||
module="llama_stack.providers.remote.inference.vertexai",
|
"litellm",
|
||||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
"google-cloud-aiplatform",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
],
|
||||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
module="llama_stack.providers.remote.inference.vertexai",
|
||||||
|
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||||
|
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||||
|
|
||||||
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
||||||
• Better integration: Seamless integration with other Google Cloud services
|
• Better integration: Seamless integration with other Google Cloud services
|
||||||
|
@ -239,76 +228,73 @@ Available Models:
|
||||||
- vertex_ai/gemini-2.0-flash
|
- vertex_ai/gemini-2.0-flash
|
||||||
- vertex_ai/gemini-2.5-flash
|
- vertex_ai/gemini-2.5-flash
|
||||||
- vertex_ai/gemini-2.5-pro""",
|
- vertex_ai/gemini-2.5-pro""",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="groq",
|
||||||
adapter_type="groq",
|
provider_type="remote::groq",
|
||||||
pip_packages=["litellm"],
|
pip_packages=[
|
||||||
module="llama_stack.providers.remote.inference.groq",
|
"litellm",
|
||||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
],
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
module="llama_stack.providers.remote.inference.groq",
|
||||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||||
),
|
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||||
|
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="llama-openai-compat",
|
||||||
adapter_type="llama-openai-compat",
|
provider_type="remote::llama-openai-compat",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm"],
|
||||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="sambanova",
|
||||||
adapter_type="sambanova",
|
provider_type="remote::sambanova",
|
||||||
pip_packages=["litellm"],
|
pip_packages=[
|
||||||
module="llama_stack.providers.remote.inference.sambanova",
|
"litellm",
|
||||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
],
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
module="llama_stack.providers.remote.inference.sambanova",
|
||||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||||
),
|
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||||
|
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="passthrough",
|
||||||
adapter_type="passthrough",
|
provider_type="remote::passthrough",
|
||||||
pip_packages=[],
|
pip_packages=[],
|
||||||
module="llama_stack.providers.remote.inference.passthrough",
|
module="llama_stack.providers.remote.inference.passthrough",
|
||||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter_type="watsonx",
|
||||||
adapter_type="watsonx",
|
provider_type="remote::watsonx",
|
||||||
pip_packages=["ibm_watsonx_ai"],
|
pip_packages=["ibm_watsonx_ai"],
|
||||||
module="llama_stack.providers.remote.inference.watsonx",
|
module="llama_stack.providers.remote.inference.watsonx",
|
||||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
provider_type="remote::azure",
|
||||||
adapter_type="azure",
|
adapter_type="azure",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm"],
|
||||||
module="llama_stack.providers.remote.inference.azure",
|
module="llama_stack.providers.remote.inference.azure",
|
||||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||||
description="""
|
description="""
|
||||||
Azure OpenAI inference provider for accessing GPT models and other Azure services.
|
Azure OpenAI inference provider for accessing GPT models and other Azure services.
|
||||||
Provider documentation
|
Provider documentation
|
||||||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||||
""",
|
""",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||||
|
|
||||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||||
|
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
],
|
],
|
||||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.post_training,
|
api=Api.post_training,
|
||||||
adapter=AdapterSpec(
|
adapter_type="nvidia",
|
||||||
adapter_type="nvidia",
|
provider_type="remote::nvidia",
|
||||||
pip_packages=["requests", "aiohttp"],
|
pip_packages=["requests", "aiohttp"],
|
||||||
module="llama_stack.providers.remote.post_training.nvidia",
|
module="llama_stack.providers.remote.post_training.nvidia",
|
||||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,11 +6,10 @@
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
AdapterSpec,
|
|
||||||
Api,
|
Api,
|
||||||
InlineProviderSpec,
|
InlineProviderSpec,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
remote_provider_spec,
|
RemoteProviderSpec,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||||
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter_type="bedrock",
|
||||||
adapter_type="bedrock",
|
provider_type="remote::bedrock",
|
||||||
pip_packages=["boto3"],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.remote.safety.bedrock",
|
module="llama_stack.providers.remote.safety.bedrock",
|
||||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter_type="nvidia",
|
||||||
adapter_type="nvidia",
|
provider_type="remote::nvidia",
|
||||||
pip_packages=["requests"],
|
pip_packages=["requests"],
|
||||||
module="llama_stack.providers.remote.safety.nvidia",
|
module="llama_stack.providers.remote.safety.nvidia",
|
||||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter_type="sambanova",
|
||||||
adapter_type="sambanova",
|
provider_type="remote::sambanova",
|
||||||
pip_packages=["litellm", "requests"],
|
pip_packages=["litellm", "requests"],
|
||||||
module="llama_stack.providers.remote.safety.sambanova",
|
module="llama_stack.providers.remote.safety.sambanova",
|
||||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,11 +6,10 @@
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
AdapterSpec,
|
|
||||||
Api,
|
Api,
|
||||||
InlineProviderSpec,
|
InlineProviderSpec,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
remote_provider_spec,
|
RemoteProviderSpec,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter_type="brave-search",
|
||||||
adapter_type="brave-search",
|
provider_type="remote::brave-search",
|
||||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||||
pip_packages=["requests"],
|
pip_packages=["requests"],
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter_type="bing-search",
|
||||||
adapter_type="bing-search",
|
provider_type="remote::bing-search",
|
||||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||||
pip_packages=["requests"],
|
pip_packages=["requests"],
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter_type="tavily-search",
|
||||||
adapter_type="tavily-search",
|
provider_type="remote::tavily-search",
|
||||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||||
pip_packages=["requests"],
|
pip_packages=["requests"],
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter_type="wolfram-alpha",
|
||||||
adapter_type="wolfram-alpha",
|
provider_type="remote::wolfram-alpha",
|
||||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||||
pip_packages=["requests"],
|
pip_packages=["requests"],
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
adapter=AdapterSpec(
|
adapter_type="model-context-protocol",
|
||||||
adapter_type="model-context-protocol",
|
provider_type="remote::model-context-protocol",
|
||||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||||
pip_packages=["mcp>=1.8.1"],
|
pip_packages=["mcp>=1.8.1"],
|
||||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||||
),
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,11 +6,10 @@
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
AdapterSpec,
|
|
||||||
Api,
|
Api,
|
||||||
InlineProviderSpec,
|
InlineProviderSpec,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
remote_provider_spec,
|
RemoteProviderSpec,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
||||||
Please refer to the sqlite-vec provider documentation.
|
Please refer to the sqlite-vec provider documentation.
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
Api.vector_io,
|
api=Api.vector_io,
|
||||||
AdapterSpec(
|
adapter_type="chromadb",
|
||||||
adapter_type="chromadb",
|
provider_type="remote::chromadb",
|
||||||
pip_packages=["chromadb-client"],
|
pip_packages=["chromadb-client"],
|
||||||
module="llama_stack.providers.remote.vector_io.chroma",
|
module="llama_stack.providers.remote.vector_io.chroma",
|
||||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||||
description="""
|
api_dependencies=[Api.inference],
|
||||||
|
optional_api_dependencies=[Api.files],
|
||||||
|
description="""
|
||||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||||
That means you're not limited to storing vectors in memory or in a separate service.
|
That means you're not limited to storing vectors in memory or in a separate service.
|
||||||
|
@ -340,9 +341,6 @@ pip install chromadb
|
||||||
## Documentation
|
## Documentation
|
||||||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
||||||
""",
|
""",
|
||||||
),
|
|
||||||
api_dependencies=[Api.inference],
|
|
||||||
optional_api_dependencies=[Api.files],
|
|
||||||
),
|
),
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.vector_io,
|
api=Api.vector_io,
|
||||||
|
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
||||||
|
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
Api.vector_io,
|
api=Api.vector_io,
|
||||||
AdapterSpec(
|
adapter_type="pgvector",
|
||||||
adapter_type="pgvector",
|
provider_type="remote::pgvector",
|
||||||
pip_packages=["psycopg2-binary"],
|
pip_packages=["psycopg2-binary"],
|
||||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||||
description="""
|
api_dependencies=[Api.inference],
|
||||||
|
optional_api_dependencies=[Api.files],
|
||||||
|
description="""
|
||||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||||
allows you to store and query vectors directly in memory.
|
allows you to store and query vectors directly in memory.
|
||||||
That means you'll get fast and efficient vector retrieval.
|
That means you'll get fast and efficient vector retrieval.
|
||||||
|
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
|
||||||
## Documentation
|
## Documentation
|
||||||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
RemoteProviderSpec(
|
||||||
|
api=Api.vector_io,
|
||||||
|
adapter_type="weaviate",
|
||||||
|
provider_type="remote::weaviate",
|
||||||
|
pip_packages=["weaviate-client"],
|
||||||
|
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||||
|
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||||
api_dependencies=[Api.inference],
|
api_dependencies=[Api.inference],
|
||||||
optional_api_dependencies=[Api.files],
|
optional_api_dependencies=[Api.files],
|
||||||
),
|
description="""
|
||||||
remote_provider_spec(
|
|
||||||
Api.vector_io,
|
|
||||||
AdapterSpec(
|
|
||||||
adapter_type="weaviate",
|
|
||||||
pip_packages=["weaviate-client"],
|
|
||||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
|
||||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
|
||||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
|
||||||
description="""
|
|
||||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||||
It allows you to store and query vectors directly within a Weaviate database.
|
It allows you to store and query vectors directly within a Weaviate database.
|
||||||
That means you're not limited to storing vectors in memory or in a separate service.
|
That means you're not limited to storing vectors in memory or in a separate service.
|
||||||
|
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
|
||||||
## Documentation
|
## Documentation
|
||||||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
||||||
""",
|
""",
|
||||||
),
|
|
||||||
api_dependencies=[Api.inference],
|
|
||||||
optional_api_dependencies=[Api.files],
|
|
||||||
),
|
),
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.vector_io,
|
api=Api.vector_io,
|
||||||
|
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
|
||||||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
Api.vector_io,
|
api=Api.vector_io,
|
||||||
AdapterSpec(
|
adapter_type="qdrant",
|
||||||
adapter_type="qdrant",
|
provider_type="remote::qdrant",
|
||||||
pip_packages=["qdrant-client"],
|
pip_packages=["qdrant-client"],
|
||||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||||
description="""
|
|
||||||
Please refer to the inline provider documentation.
|
|
||||||
""",
|
|
||||||
),
|
|
||||||
api_dependencies=[Api.inference],
|
api_dependencies=[Api.inference],
|
||||||
optional_api_dependencies=[Api.files],
|
optional_api_dependencies=[Api.files],
|
||||||
|
description="""
|
||||||
|
Please refer to the inline provider documentation.
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
RemoteProviderSpec(
|
||||||
Api.vector_io,
|
api=Api.vector_io,
|
||||||
AdapterSpec(
|
adapter_type="milvus",
|
||||||
adapter_type="milvus",
|
provider_type="remote::milvus",
|
||||||
pip_packages=["pymilvus>=2.4.10"],
|
pip_packages=["pymilvus>=2.4.10"],
|
||||||
module="llama_stack.providers.remote.vector_io.milvus",
|
module="llama_stack.providers.remote.vector_io.milvus",
|
||||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||||
description="""
|
api_dependencies=[Api.inference],
|
||||||
|
optional_api_dependencies=[Api.files],
|
||||||
|
description="""
|
||||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||||
allows you to store and query vectors directly within a Milvus database.
|
allows you to store and query vectors directly within a Milvus database.
|
||||||
That means you're not limited to storing vectors in memory or in a separate service.
|
That means you're not limited to storing vectors in memory or in a separate service.
|
||||||
|
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install Milvus using pymilvus:
|
If you want to use inline Milvus, you can install:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install pymilvus[milvus-lite]
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to use remote Milvus, you can install:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install pymilvus
|
pip install pymilvus
|
||||||
|
@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
|
||||||
|
|
||||||
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
||||||
""",
|
""",
|
||||||
),
|
|
||||||
api_dependencies=[Api.inference],
|
|
||||||
optional_api_dependencies=[Api.files],
|
|
||||||
),
|
),
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.vector_io,
|
api=Api.vector_io,
|
||||||
provider_type="inline::milvus",
|
provider_type="inline::milvus",
|
||||||
pip_packages=["pymilvus>=2.4.10"],
|
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||||
module="llama_stack.providers.inline.vector_io.milvus",
|
module="llama_stack.providers.inline.vector_io.milvus",
|
||||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||||
api_dependencies=[Api.inference],
|
api_dependencies=[Api.inference],
|
||||||
|
|
|
@ -137,7 +137,7 @@ class S3FilesImpl(Files):
|
||||||
where: dict[str, str | dict] = {"id": file_id}
|
where: dict[str, str | dict] = {"id": file_id}
|
||||||
if not return_expired:
|
if not return_expired:
|
||||||
where["expires_at"] = {">": self._now()}
|
where["expires_at"] = {">": self._now()}
|
||||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
|
||||||
self._client = _create_s3_client(self._config)
|
self._client = _create_s3_client(self._config)
|
||||||
await _create_bucket_if_not_exists(self._client, self._config)
|
await _create_bucket_if_not_exists(self._client, self._config)
|
||||||
|
|
||||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||||
await self._sql_store.create_table(
|
await self._sql_store.create_table(
|
||||||
"openai_files",
|
"openai_files",
|
||||||
{
|
{
|
||||||
|
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="openai_files",
|
table="openai_files",
|
||||||
policy=self.policy,
|
|
||||||
where=where_conditions,
|
where=where_conditions,
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
|
@ -4,15 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .config import AnthropicConfig
|
from .config import AnthropicConfig
|
||||||
|
|
||||||
|
|
||||||
class AnthropicProviderDataValidator(BaseModel):
|
|
||||||
anthropic_api_key: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||||
from .anthropic import AnthropicInferenceAdapter
|
from .anthropic import AnthropicInferenceAdapter
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fireworks.client import Fireworks
|
from fireworks.client import Fireworks
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
OpenAIChatCompletion,
|
|
||||||
OpenAIChatCompletionChunk,
|
|
||||||
OpenAICompletion,
|
|
||||||
OpenAIEmbeddingsResponse,
|
|
||||||
OpenAIMessageParam,
|
|
||||||
OpenAIResponseFormatParam,
|
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
OpenAIChatCompletionToLlamaStackMixin,
|
|
||||||
convert_message_to_openai_dict,
|
convert_message_to_openai_dict,
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
prepare_openai_completion_params,
|
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
process_chat_completion_stream_response,
|
process_chat_completion_stream_response,
|
||||||
process_completion_response,
|
process_completion_response,
|
||||||
process_completion_stream_response,
|
process_completion_stream_response,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
chat_completion_request_to_prompt,
|
chat_completion_request_to_prompt,
|
||||||
completion_request_to_prompt,
|
completion_request_to_prompt,
|
||||||
|
@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
|
||||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
|
||||||
if config_api_key:
|
if config_api_key:
|
||||||
return config_api_key
|
return config_api_key
|
||||||
|
@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
)
|
)
|
||||||
return provider_data.fireworks_api_key
|
return provider_data.fireworks_api_key
|
||||||
|
|
||||||
def _get_base_url(self) -> str:
|
def get_base_url(self) -> str:
|
||||||
return "https://api.fireworks.ai/inference/v1"
|
return "https://api.fireworks.ai/inference/v1"
|
||||||
|
|
||||||
def _get_client(self) -> Fireworks:
|
def _get_client(self) -> Fireworks:
|
||||||
fireworks_api_key = self._get_api_key()
|
fireworks_api_key = self.get_api_key()
|
||||||
return Fireworks(api_key=fireworks_api_key)
|
return Fireworks(api_key=fireworks_api_key)
|
||||||
|
|
||||||
def _get_openai_client(self) -> AsyncOpenAI:
|
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
|
||||||
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
|
"""Remove BOS token as Fireworks automatically prepends it"""
|
||||||
|
if prompt.startswith("<|begin_of_text|>"):
|
||||||
|
return prompt[len("<|begin_of_text|>") :]
|
||||||
|
return prompt
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
|
|
||||||
embeddings = [data.embedding for data in response.data]
|
embeddings = [data.embedding for data in response.data]
|
||||||
return EmbeddingsResponse(embeddings=embeddings)
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
async def openai_embeddings(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
input: str | list[str],
|
|
||||||
encoding_format: str | None = "float",
|
|
||||||
dimensions: int | None = None,
|
|
||||||
user: str | None = None,
|
|
||||||
) -> OpenAIEmbeddingsResponse:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def openai_completion(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
prompt: str | list[str] | list[int] | list[list[int]],
|
|
||||||
best_of: int | None = None,
|
|
||||||
echo: bool | None = None,
|
|
||||||
frequency_penalty: float | None = None,
|
|
||||||
logit_bias: dict[str, float] | None = None,
|
|
||||||
logprobs: bool | None = None,
|
|
||||||
max_tokens: int | None = None,
|
|
||||||
n: int | None = None,
|
|
||||||
presence_penalty: float | None = None,
|
|
||||||
seed: int | None = None,
|
|
||||||
stop: str | list[str] | None = None,
|
|
||||||
stream: bool | None = None,
|
|
||||||
stream_options: dict[str, Any] | None = None,
|
|
||||||
temperature: float | None = None,
|
|
||||||
top_p: float | None = None,
|
|
||||||
user: str | None = None,
|
|
||||||
guided_choice: list[str] | None = None,
|
|
||||||
prompt_logprobs: int | None = None,
|
|
||||||
suffix: str | None = None,
|
|
||||||
) -> OpenAICompletion:
|
|
||||||
model_obj = await self.model_store.get_model(model)
|
|
||||||
|
|
||||||
# Fireworks always prepends with BOS
|
|
||||||
        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
            prompt = prompt[len("<|begin_of_text|>") :]

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
        )

        return await self._get_openai_client().completions.create(**params)

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        model_obj = await self.model_store.get_model(model)

        # Divert Llama Models through Llama Stack inference APIs because
        # Fireworks chat completions OpenAI-compatible API does not support
        # tool calls properly.
        llama_model = self.get_llama_model(model_obj.provider_resource_id)

        if llama_model:
            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
                self,
                model=model,
                messages=messages,
                frequency_penalty=frequency_penalty,
                function_call=function_call,
                functions=functions,
                logit_bias=logit_bias,
                logprobs=logprobs,
                max_completion_tokens=max_completion_tokens,
                max_tokens=max_tokens,
                n=n,
                parallel_tool_calls=parallel_tool_calls,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                stream=stream,
                stream_options=stream_options,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_logprobs=top_logprobs,
                top_p=top_p,
                user=user,
            )

        params = await prepare_openai_completion_params(
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        logger.debug(f"fireworks params: {params}")
        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
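The removed Fireworks path above routes models that have a known Llama mapping through the Llama Stack conversion mixin and sends everything else to the OpenAI-compatible client. A minimal sketch of that dispatch shape, using hypothetical names (route_chat_completion, llama_model_lookup) rather than the provider's real classes:

```python
# Illustrative sketch only: route models with a known Llama mapping through the
# native Llama Stack path, everything else through the OpenAI-compatible path.
from typing import Any, Callable


def route_chat_completion(
    model: str,
    llama_model_lookup: Callable[[str], str | None],
    llama_path: Callable[..., Any],
    openai_path: Callable[..., Any],
    **params: Any,
) -> Any:
    """Return the result of whichever backend path applies to `model`."""
    if llama_model_lookup(model):
        # Llama models: tool calls go through the Llama Stack conversion path.
        return llama_path(model=model, **params)
    # Everything else: pass through to the OpenAI-compatible endpoint.
    return openai_path(model=model, **params)


if __name__ == "__main__":
    lookup = {"accounts/fireworks/models/llama-v3p1-8b-instruct": "Llama-3.1-8B"}.get
    result = route_chat_completion(
        "accounts/fireworks/models/llama-v3p1-8b-instruct",
        llama_model_lookup=lookup,
        llama_path=lambda **kw: ("llama-stack", kw["model"]),
        openai_path=lambda **kw: ("openai-compat", kw["model"]),
        messages=[{"role": "user", "content": "hi"}],
    )
    print(result)  # -> ('llama-stack', 'accounts/fireworks/models/llama-v3p1-8b-instruct')
```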
@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import GeminiConfig


class GeminiProviderDataValidator(BaseModel):
    gemini_api_key: str | None = None


async def get_adapter_impl(config: GeminiConfig, _deps):
    from .gemini import GeminiInferenceAdapter
@ -7,12 +7,10 @@

import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

from ollama import AsyncClient  # type: ignore[attr-defined]
from ollama import AsyncClient as AsyncOllamaClient
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
    Message,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ResponseFormat,
@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    b64_encode_openai_embeddings_response,
    get_sampling_options,
    prepare_openai_completion_params,
    prepare_openai_embeddings_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")


class OllamaInferenceAdapter(
    OpenAIMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
    def __init__(self, config: OllamaImplConfig) -> None:
        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
        self.config = config
        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
        self._openai_client = None

    @property
    def client(self) -> AsyncClient:
    def ollama_client(self) -> AsyncOllamaClient:
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncClient(host=self.config.url)
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
        return self._clients[loop]

    @property
    def openai_client(self) -> AsyncOpenAI:
        if self._openai_client is None:
            url = self.config.url.rstrip("/")
            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
        return self._openai_client

    def get_api_key(self):
        return "NO_KEY"

    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
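The ollama_client property above keeps one client per running event loop because the underlying client binds itself to the loop it was created on; the OpenAI-compatible side is now supplied by OpenAIMixin through the get_api_key/get_base_url hooks instead of a cached AsyncOpenAI instance. A small illustrative sketch of the per-loop cache, using a stand-in LoopScopedClientCache class rather than the adapter itself:

```python
# Illustrative sketch (not the adapter): cache one client instance per running
# event loop, creating it lazily the first time that loop asks for it.
import asyncio


class LoopScopedClientCache:
    def __init__(self, factory):
        self._factory = factory
        self._clients: dict[asyncio.AbstractEventLoop, object] = {}

    @property
    def client(self):
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = self._factory()
        return self._clients[loop]


async def main() -> None:
    cache = LoopScopedClientCache(factory=object)
    # Two lookups on the same loop return the same cached instance.
    assert cache.client is cache.client


asyncio.run(main())
```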
@ -129,7 +122,7 @@ class OllamaInferenceAdapter(

    async def list_models(self) -> list[Model] | None:
        provider_id = self.__provider_id__
        response = await self.client.list()
        response = await self.ollama_client.list()

        # always add the two embedding models which can be pulled on demand
        models = [
@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
            HealthResponse: A dictionary containing the health status.
        """
        try:
            await self.client.ps()
            await self.ollama_client.ps()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            s = await self.ollama_client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
@ -254,7 +247,7 @@ class OllamaInferenceAdapter(

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)
        r = await self.ollama_client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
            r = await self.ollama_client.chat(**params)
        else:
            r = await self.client.generate(**params)
            r = await self.ollama_client.generate(**params)

        if "message" in r:
            choice = OpenAICompatCompletionChoice(
@ -372,9 +365,9 @@ class OllamaInferenceAdapter(

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
                s = await self.ollama_client.chat(**params)
            else:
                s = await self.client.generate(**params)
                s = await self.ollama_client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
        response = await self.ollama_client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
            pass  # Ignore statically unknown model, will check live listing

        if model.model_type == ModelType.embedding:
            response = await self.client.list()
            response = await self.ollama_client.list()
            if model.provider_resource_id not in [m.model for m in response.models]:
                await self.client.pull(model.provider_resource_id)
                await self.ollama_client.pull(model.provider_resource_id)

        # we use list() here instead of ps() -
        # - ps() only lists running models, not available models
        # - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        response = await self.ollama_client.list()
        available_models = [m.model for m in response.models]

        provider_resource_id = model.provider_resource_id
@ -448,90 +441,6 @@ class OllamaInferenceAdapter(

        return model

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        model_obj = await self._get_model(model)
        if model_obj.provider_resource_id is None:
            raise ValueError(f"Model {model} has no provider_resource_id set")

        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
        params = prepare_openai_embeddings_params(
            model=model_obj.provider_resource_id,
            input=input,
            encoding_format=encoding_format,
            dimensions=dimensions,
            user=user,
        )

        response = await self.openai_client.embeddings.create(**params)
        data = b64_encode_openai_embeddings_response(response.data, encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )
        # TODO: Investigate why model_obj.identifier is used instead of response.model
        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.identifier,
            usage=usage,
        )

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
            suffix=suffix,
        )
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
            top_p=top_p,
            user=user,
        )
        response = await self.openai_client.chat.completions.create(**params)
        return await self._adjust_ollama_chat_completion_response_ids(response)
        return await OpenAIMixin.openai_chat_completion(self, **params)

    async def _adjust_ollama_chat_completion_response_ids(
        self,
        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        id = f"chatcmpl-{uuid.uuid4()}"
        if isinstance(response, AsyncIterator):

            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
                async for chunk in response:
                    chunk.id = id
                    yield chunk

            return stream_with_chunk_ids()
        else:
            response.id = id
            return response


async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import OpenAIConfig


class OpenAIProviderDataValidator(BaseModel):
    openai_api_key: str | None = None


async def get_adapter_impl(config: OpenAIConfig, _deps):
    from .openai import OpenAIInferenceAdapter
@ -54,7 +54,7 @@ class InferenceStore:

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
        await self.sql_store.create_table(
            "chat_completions",
            {
@ -202,7 +202,6 @@ class InferenceStore:
            order_by=[("created", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
            policy=self.policy,
        )

        data = [
@ -229,7 +228,6 @@ class InferenceStore:
        row = await self.sql_store.fetch_one(
            table="chat_completions",
            where={"id": completion_id},
            policy=self.policy,
        )

        if not row:
@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
            Model(
                identifier=id,
                provider_resource_id=entry.provider_model_id,
                model_type=ModelType.llm,
                model_type=entry.model_type,
                metadata=entry.metadata,
                provider_id=self.__provider_id__,
            )
@ -28,8 +28,7 @@ class ResponsesStore:
            sql_store_config = SqliteSqlStoreConfig(
                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
            )
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
        self.policy = policy

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
            policy=self.policy,
        )

        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
        row = await self.sql_store.fetch_one(
            "openai_responses",
            where={"id": response_id},
            policy=self.policy,
        )

        if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
        return OpenAIResponseObjectWithInput(**row["response_object"])

    async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
        if not row:
            raise ValueError(f"Response with id {response_id} not found")
        await self.sql_store.delete("openai_responses", where={"id": response_id})
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
    access control policies, user attribute capture, and SQL filtering optimization.
    """

    def __init__(self, sql_store: SqlStore):
    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
        """
        Initialize the authorization layer.

        :param sql_store: Base SqlStore implementation to wrap
        :param policy: Access control policy to use for authorization
        """
        self.sql_store = sql_store
        self.policy = policy
        self._detect_database_type()
        self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
    async def fetch_all(
        self,
        table: str,
        policy: list[AccessRule],
        where: Mapping[str, Any] | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        """Fetch all rows with automatic access control filtering."""
        access_where = self._build_access_control_where_clause(policy)
        access_where = self._build_access_control_where_clause(self.policy)
        rows = await self.sql_store.fetch_all(
            table=table,
            where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
                    str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
                )

                if is_action_allowed(policy, Action.READ, sql_record, current_user):
                if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
                    filtered_rows.append(row)

        return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
    async def fetch_one(
        self,
        table: str,
        policy: list[AccessRule],
        where: Mapping[str, Any] | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        """Fetch one row with automatic access control checking."""
        results = await self.fetch_all(
            table=table,
            policy=policy,
            where=where,
            limit=1,
            order_by=order_by,
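With this change the access policy is bound once in the AuthorizedSqlStore constructor instead of being passed to every fetch_all/fetch_one call. A hedged, stand-in sketch of that design choice (PolicyBoundStore is hypothetical, not the real class):

```python
# Stand-in sketch: bind the policy at construction so callers no longer thread
# it through each fetch; filtering still happens on every read.
from dataclasses import dataclass, field


@dataclass
class PolicyBoundStore:
    rows: list[dict] = field(default_factory=list)
    policy: set[str] = field(default_factory=set)  # principals allowed to read

    def fetch_all(self, current_user: str) -> list[dict]:
        # The policy was fixed at construction; no per-call policy argument.
        return [r for r in self.rows if not self.policy or current_user in self.policy]


store = PolicyBoundStore(rows=[{"id": "1"}, {"id": "2"}], policy={"admin"})
assert store.fetch_all("admin") == [{"id": "1"}, {"id": "2"}]
assert store.fetch_all("guest") == []
```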
@ -203,7 +203,12 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
    - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
    Returns a list of unique identifiers or None if structure doesn't match.
    """
    items = response["body"]
    if "models" in response["body"]:
        # ollama
        items = response["body"]["models"]
    else:
        # openai
        items = response["body"]
    idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
    return sorted(set(idents))
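The new branch above normalizes the two response shapes the recorder sees: Ollama's /api/tags body nests the list under "models", while an OpenAI-style /v1/models body is the list itself. A small worked example with stand-in data (not taken from a real recording):

```python
# Stand-in data illustrating the two body shapes the digest helper handles.
from types import SimpleNamespace

ollama_body = {"models": [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="all-minilm")]}
openai_body = [SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o")]


def identifiers(endpoint: str, body) -> list[str]:
    # Ollama wraps the list under "models"; OpenAI returns the list directly.
    items = body["models"] if isinstance(body, dict) and "models" in body else body
    idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
    return sorted(set(idents))


print(identifiers("/api/tags", ollama_body))   # ['all-minilm', 'llama3.2:3b']
print(identifiers("/v1/models", openai_body))  # ['gpt-4o', 'gpt-4o-mini']
```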
1724  llama_stack/ui/package-lock.json  generated
File diff suppressed because it is too large
@ -14,7 +14,7 @@
  },
  "dependencies": {
    "@radix-ui/react-collapsible": "^1.1.12",
    "@radix-ui/react-dialog": "^1.1.13",
    "@radix-ui/react-dialog": "^1.1.15",
    "@radix-ui/react-dropdown-menu": "^2.1.16",
    "@radix-ui/react-select": "^2.2.6",
    "@radix-ui/react-separator": "^1.1.7",
@ -32,7 +32,7 @@
    "react-dom": "^19.1.1",
    "react-markdown": "^10.1.0",
    "remark-gfm": "^4.0.1",
    "remeda": "^2.30.0",
    "remeda": "^2.32.0",
    "shiki": "^1.29.2",
    "sonner": "^2.0.7",
    "tailwind-merge": "^3.3.1"
@ -52,7 +52,7 @@
    "eslint-config-prettier": "^10.1.8",
    "eslint-plugin-prettier": "^5.5.4",
    "jest": "^29.7.0",
    "jest-environment-jsdom": "^29.7.0",
    "jest-environment-jsdom": "^30.1.2",
    "prettier": "3.6.2",
    "tailwindcss": "^4",
    "ts-node": "^10.9.2",
9  tests/external/kaze.yaml  vendored
@ -1,6 +1,5 @@
adapter:
  adapter_type: kaze
  pip_packages: ["tests/external/llama-stack-provider-kaze"]
  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
  module: llama_stack_provider_kaze
adapter_type: kaze
pip_packages: ["tests/external/llama-stack-provider-kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
optional_api_dependencies: []
@ -6,7 +6,7 @@

from typing import Protocol

from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod


@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
            api=Api.weather,
            provider_type="remote::kaze",
            config_class="llama_stack_provider_kaze.KazeProviderConfig",
            adapter=AdapterSpec(
                adapter_type="kaze",
                module="llama_stack_provider_kaze",
                pip_packages=["llama_stack_provider_kaze"],
                config_class="llama_stack_provider_kaze.KazeProviderConfig",
            ),
            adapter_type="kaze",
            module="llama_stack_provider_kaze",
            pip_packages=["llama_stack_provider_kaze"],
        ),
    ]
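After this change the external provider spec no longer nests an AdapterSpec; adapter_type, module, and pip_packages sit directly on the remote provider spec. A hedged sketch using a stand-in dataclass (FlatRemoteProviderSpec is hypothetical, not the real llama_stack.providers.datatypes type):

```python
# Stand-in illustration of the flattened spec shape; field names mirror the
# diff above, the class itself is hypothetical.
from dataclasses import dataclass, field


@dataclass
class FlatRemoteProviderSpec:
    api: str
    provider_type: str
    config_class: str
    adapter_type: str
    module: str
    pip_packages: list[str] = field(default_factory=list)


spec = FlatRemoteProviderSpec(
    api="weather",
    provider_type="remote::kaze",
    config_class="llama_stack_provider_kaze.KazeProviderConfig",
    adapter_type="kaze",
    module="llama_stack_provider_kaze",
    pip_packages=["llama_stack_provider_kaze"],
)
print(spec.adapter_type, spec.module)
```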
@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
    provider = provider_from_model(client, model_id)
    if provider.provider_type in (
        "remote::together",  # service returns 400
        "remote::fireworks",  # service returns 400 malformed input
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
    provider = provider_from_model(client, model_id)
    if provider.provider_type in (
        "remote::together",  # param silently ignored, always returns floats
        "remote::fireworks",  # param silently ignored, always returns list of floats
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
        input=input_texts,
        encoding_format="base64",
    )

    # Validate response structure
    assert response.object == "list"
    assert response.model == embedding_model_id
@ -57,7 +57,7 @@ def authorized_store(backend_config):
    config = config_func()

    base_sqlstore = sqlstore_impl(config)
    authorized_store = AuthorizedSqlStore(base_sqlstore)
    authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

    yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
    await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})

    # Test fetching with no user - should not error on JSON comparison
    result = await authorized_store.fetch_all(table_name, policy=default_policy())
    result = await authorized_store.fetch_all(table_name)
    assert len(result.data) == 1
    assert result.data[0]["id"] == "1"
    assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
    await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})

    # Fetch all - admin should see both
    result = await authorized_store.fetch_all(table_name, policy=default_policy())
    result = await authorized_store.fetch_all(table_name)
    assert len(result.data) == 2

    # Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
    mock_get_authenticated_user.return_value = regular_user

    # Should only see public record
    result = await authorized_store.fetch_all(table_name, policy=default_policy())
    result = await authorized_store.fetch_all(table_name)
    assert len(result.data) == 1
    assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz

    # Now test with the multi-user who has both roles=admin and teams=dev
    mock_get_authenticated_user.return_value = multi_user
    result = await authorized_store.fetch_all(table_name, policy=default_policy())
    result = await authorized_store.fetch_all(table_name)

    # Should see:
    # - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
            ),
        ]

        # Create a new authorized store with the owner-only policy
        owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)

        # Test user1 access - should only see their own record
        mock_get_authenticated_user.return_value = user1
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        result = await owner_only_store.fetch_all(table_name)
        assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
        assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"

        # Test user2 access - should only see their own record
        mock_get_authenticated_user.return_value = user2
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        result = await owner_only_store.fetch_all(table_name)
        assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
        assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"

        # Test with anonymous user - should see no records
        mock_get_authenticated_user.return_value = None
        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
        result = await owner_only_store.fetch_all(table_name)
        assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"

    finally:
990  tests/integration/recordings/responses/3a81146f2afa.json  Normal file
@ -0,0 +1,990 @@
{
  "request": {
    "method": "POST",
    "url": "http://0.0.0.0:11434/v1/v1/completions",
    "headers": {},
    "body": {
      "model": "llama3.2:3b-instruct-fp16",
      "prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
      "max_tokens": 50,
      "stream": true,
      "extra_body": {}
    },
    "endpoint": "/v1/completions",
    "model": "llama3.2:3b-instruct-fp16"
  },
  "response": {
    "body": [
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "Blue"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ".\n\n"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "The"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " completed"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " sentence"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " is"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " a"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " well"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "-known"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " phrase"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " from"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " a"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " traditional"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " English"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " poem"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ":\n\n"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "\""}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "R"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "oses"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " are"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " red"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ","}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " v"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "io"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "lets"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " are"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " blue"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ",\n"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": "Sugar"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " is"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " sweet"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ","}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " and"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " so"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " are"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " you"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ".\""}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " However"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": ","}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " in"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " many"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion", "__data__": {"id": "cmpl-439", "choices": [{"finish_reason": null, "index": 0, "logprobs": null, "text": " variations"}], "created": 1757857132, "model": "llama3.2:3b-instruct-fp16", "object": "text_completion", "system_fingerprint": "fp_ollama", "usage": null}},
      {"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " of"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " this"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": ","
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " the"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " line"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": " \""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": "vio"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-439",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": ""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1757857132,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"is_streaming": true
|
||||||
|
}
|
||||||
|
}
|
43
tests/integration/recordings/responses/6412295819a1.json
Normal file
@@ -0,0 +1,43 @@
{
  "request": {
    "method": "POST",
    "url": "http://0.0.0.0:11434/v1/v1/completions",
    "headers": {},
    "body": {
      "model": "llama3.2:3b-instruct-fp16",
      "prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
      "stream": false,
      "extra_body": {}
    },
    "endpoint": "/v1/completions",
    "model": "llama3.2:3b-instruct-fp16"
  },
  "response": {
    "body": {
      "__type__": "openai.types.completion.Completion",
      "__data__": {
        "id": "cmpl-104",
        "choices": [
          {
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "text": "blue.\n\nI completed the sentence with \"blue\" because it is a common completion used to complete the traditional nursery rhyme, which ends with:\n\nRoses are red,\nViolets are blue.\n\nThe complete rhyme is often remembered and recited as follows:\n\nRoses are red,\nViolets are blue,\nSugar is sweet,\nAnd so are you!"
          }
        ],
        "created": 1757857132,
        "model": "llama3.2:3b-instruct-fp16",
        "object": "text_completion",
        "system_fingerprint": "fp_ollama",
        "usage": {
          "completion_tokens": 72,
          "prompt_tokens": 50,
          "total_tokens": 122,
          "completion_tokens_details": null,
          "prompt_tokens_details": null
        }
      }
    },
    "is_streaming": false
  }
}
43
tests/integration/recordings/responses/ecae140151d1.json
Normal file
@@ -0,0 +1,43 @@
{
  "request": {
    "method": "POST",
    "url": "http://0.0.0.0:11434/v1/v1/completions",
    "headers": {},
    "body": {
      "model": "llama3.2:3b-instruct-fp16",
      "prompt": "Say completions",
      "max_tokens": 20,
      "extra_body": {}
    },
    "endpoint": "/v1/completions",
    "model": "llama3.2:3b-instruct-fp16"
  },
  "response": {
    "body": {
      "__type__": "openai.types.completion.Completion",
      "__data__": {
        "id": "cmpl-406",
        "choices": [
          {
            "finish_reason": "length",
            "index": 0,
            "logprobs": null,
            "text": "Sure, I'd be happy to provide some definitions and examples of related words or phrases.\n\nTo better"
          }
        ],
        "created": 1757857133,
        "model": "llama3.2:3b-instruct-fp16",
        "object": "text_completion",
        "system_fingerprint": "fp_ollama",
        "usage": {
          "completion_tokens": 20,
          "prompt_tokens": 28,
          "total_tokens": 48,
          "completion_tokens_details": null,
          "prompt_tokens_details": null
        }
      }
    },
    "is_streaming": false
  }
}
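Both recordings above capture plain (non-chat) completion requests against Ollama's OpenAI-compatible endpoint. A minimal sketch of the kind of client call that would produce the second recording, assuming the `openai` Python package and a locally running Ollama server; the model name, prompt, and `max_tokens` value are taken from the request body above, and the `api_key` is a placeholder that Ollama ignores.

```python
# Sketch only: reproduces the request captured in the recording above.
# Assumes a local Ollama server exposing its OpenAI-compatible API and the
# `openai` Python client.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:11434/v1", api_key="ollama")

response = client.completions.create(
    model="llama3.2:3b-instruct-fp16",
    prompt="Say completions",
    max_tokens=20,
)
print(response.choices[0].text)
```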
@@ -115,6 +115,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
             "text_model": "cerebras/llama-3.3-70b",
         },
     ),
+    "fireworks": Setup(
+        name="fireworks",
+        description="Fireworks provider with a text model",
+        defaults={
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+        },
+    ),
 }

@@ -66,10 +66,9 @@ def base_config(tmp_path):
 def provider_spec_yaml():
     """Common provider spec YAML for testing."""
     return """
-adapter:
-  adapter_type: test_provider
-  config_class: test_provider.config.TestProviderConfig
-  module: test_provider
+adapter_type: test_provider
+config_class: test_provider.config.TestProviderConfig
+module: test_provider
 api_dependencies:
 - safety
 """
@@ -182,9 +181,9 @@ class TestProviderRegistry:
         assert Api.inference in registry
         assert "remote::test_provider" in registry[Api.inference]
         provider = registry[Api.inference]["remote::test_provider"]
-        assert provider.adapter.adapter_type == "test_provider"
-        assert provider.adapter.module == "test_provider"
-        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
+        assert provider.adapter_type == "test_provider"
+        assert provider.module == "test_provider"
+        assert provider.config_class == "test_provider.config.TestProviderConfig"
         assert Api.safety in provider.api_dependencies

     def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@@ -246,8 +245,7 @@ class TestProviderRegistry:
         """Test handling of malformed remote provider spec (missing required fields)."""
         remote_dir, _ = api_directories
         malformed_spec = """
-adapter:
-  adapter_type: test_provider
+adapter_type: test_provider
 # Missing required fields
 api_dependencies:
 - safety
@@ -270,7 +268,7 @@ pip_packages:
         with open(inline_dir / "malformed.yaml", "w") as f:
             f.write(malformed_spec)

-        with pytest.raises(KeyError) as exc_info:
+        with pytest.raises(ValidationError) as exc_info:
             get_provider_registry(base_config)
         assert "config_class" in str(exc_info.value)

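The hunks above flatten the external remote provider spec: the nested `adapter:` block is gone, `adapter_type`, `config_class`, and `module` now sit at the top level, and a malformed spec surfaces as a pydantic `ValidationError` rather than a `KeyError`. A minimal sketch of what the flattened fixture parses to, assuming only PyYAML is available; the field values are the ones used by the test fixture above.

```python
# Sketch only: shows the flattened remote provider spec shape used by the
# updated test fixture. Field values come from the diff above.
import yaml

provider_spec_yaml = """
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
api_dependencies:
- safety
"""

spec = yaml.safe_load(provider_spec_yaml)
# The keys that used to live under `adapter:` are now top-level.
assert spec["adapter_type"] == "test_provider"
assert spec["config_class"] == "test_provider.config.TestProviderConfig"
assert spec["module"] == "test_provider"
assert spec["api_dependencies"] == ["safety"]
```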
@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         sync_client = LlamaStackAsLibraryClient("ci-tests")
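These hunks replace the patched `construct_stack` coroutine with a `Stack`-shaped stand-in: the library client is now expected to instantiate a class, await its `initialize()`, and read `impls` off the instance. A minimal sketch of that contract using only the standard library; `MockStack` mirrors the test double in the diff and nothing from llama-stack itself is imported.

```python
# Sketch only: exercises the contract the tests above rely on -- a Stack-like
# object that exposes `impls` after construction and a no-op async initialize().
import asyncio

mock_impls = {"inference": object()}


class MockStack:
    def __init__(self, config, custom_provider_registry=None):
        self.impls = mock_impls

    async def initialize(self):
        pass


async def main():
    stack = MockStack(config={})
    await stack.initialize()
    assert stack.impls is mock_impls


asyncio.run(main())
```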
@@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
                 db_path=tmp_dir + "/" + db_name,
             )
         )
-        sqlstore = AuthorizedSqlStore(base_sqlstore)
+        sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

         # Create table with access control
         await sqlstore.create_table(
@@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
         mock_get_authenticated_user.return_value = admin_user

         # Admin should see both documents
-        result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+        result = await sqlstore.fetch_all("documents", where={"id": 1})
         assert len(result.data) == 1
         assert result.data[0]["title"] == "Admin Document"

         # User should only see their document
         mock_get_authenticated_user.return_value = regular_user

-        result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+        result = await sqlstore.fetch_all("documents", where={"id": 1})
         assert len(result.data) == 0

-        result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
+        result = await sqlstore.fetch_all("documents", where={"id": 2})
         assert len(result.data) == 1
         assert result.data[0]["title"] == "User Document"

-        row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
+        row = await sqlstore.fetch_one("documents", where={"id": 1})
         assert row is None

-        row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
+        row = await sqlstore.fetch_one("documents", where={"id": 2})
         assert row is not None
         assert row["title"] == "User Document"

@@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
                 db_path=tmp_dir + "/" + db_name,
             )
         )
-        sqlstore = AuthorizedSqlStore(base_sqlstore)
+        sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

         await sqlstore.create_table(
             table="resources",
@@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
             user = User(principal=user_data["principal"], attributes=user_data["attributes"])
             mock_get_authenticated_user.return_value = user

-            sql_results = await sqlstore.fetch_all("resources", policy=policy)
+            sql_results = await sqlstore.fetch_all("resources")
             sql_ids = {row["id"] for row in sql_results.data}
             policy_ids = set()
             for scenario in test_scenarios:
@@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
                 db_path=tmp_dir + "/" + db_name,
             )
         )
-        authorized_store = AuthorizedSqlStore(base_sqlstore)
+        authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

         await authorized_store.create_table(
             table="user_data",
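The hunks above move the access-control policy from every `fetch_all`/`fetch_one` call into the `AuthorizedSqlStore` constructor. A minimal usage sketch under that assumption; it reuses the same imports the test module already has (`AuthorizedSqlStore`, `default_policy`), and `base_sqlstore` stands in for the SQLite-backed store the tests build. The helper name `fetch_user_documents` is hypothetical.

```python
# Sketch only: assumes the same imports and base_sqlstore fixture as the test
# module above. The policy is bound once at construction time, so the per-call
# `policy=` argument disappears from fetch_all/fetch_one.
async def fetch_user_documents(base_sqlstore):
    store = AuthorizedSqlStore(base_sqlstore, default_policy())

    # Row filtering still applies, but against the policy given to the constructor.
    page = await store.fetch_all("documents", where={"id": 1})
    row = await store.fetch_one("documents", where={"id": 2})
    return page.data, row
```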