Merge branch 'main' into allow-dynamic-models-ollama

This commit is contained in:
Matthew Farrellee 2025-07-28 14:16:31 -04:00
commit 56476fa462
247 changed files with 9176 additions and 7177 deletions

View file

@ -43,10 +43,24 @@ class ModelsProtocolPrivate(Protocol):
-> Provider uses provider-model-id for inference
"""
# this should be called `on_model_register` or something like that.
# the provider should _not_ be able to change the object in this
# callback
async def register_model(self, model: Model) -> Model: ...
async def unregister_model(self, model_id: str) -> None: ...
# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...
async def should_refresh_models(self) -> bool: ...
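A minimal sketch of a provider satisfying the refresh surface described above, assuming a static catalog that opts out of periodic refresh (the class name is illustrative, not part of this change):

class StaticCatalogProvider(ModelsProtocolPrivate):
    async def register_model(self, model: Model) -> Model:
        return model
    async def unregister_model(self, model_id: str) -> None:
        pass
    async def should_refresh_models(self) -> bool:
        return False  # nothing changes at runtime, so never refresh
    async def list_models(self) -> list[Model] | None:
        return None  # contribute nothing beyond explicitly registered models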
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
@ -104,6 +118,19 @@ class ProviderSpec(BaseModel):
description="If this provider is deprecated and does NOT work, specify the error message here",
)
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list)
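For context, a hedged sketch of what an external provider module (the `module` field above, e.g. `ramalama_stack`) is expected to expose; the adapter class imported here is a hypothetical placeholder, not part of this change:

# e.g. ramalama_stack/__init__.py (illustrative)
from typing import Any

async def get_adapter_impl(config: Any, deps: dict[Any, Any]) -> Any:
    from .adapter import RamalamaAdapter  # hypothetical adapter class
    impl = RamalamaAdapter(config, deps)
    await impl.initialize()
    return impl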
@ -113,7 +140,7 @@ class ProviderSpec(BaseModel):
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@ -124,7 +151,7 @@ class AdapterSpec(BaseModel):
description="Unique identifier for this adapter",
)
module: str = Field(
...,
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
@ -162,14 +189,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
@ -212,9 +232,7 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
@property
def module(self) -> str:
return self.adapter.module
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
@ -226,14 +244,19 @@ API responses, specify the adapter here.
def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
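A hedged usage sketch of the new `optional_api_dependencies` parameter; the adapter values are illustrative (module and config paths are assumptions), but the pattern mirrors the vector_io registration later in this diff:

spec = remote_provider_spec(
    api=Api.vector_io,
    adapter=AdapterSpec(
        adapter_type="chromadb",  # illustrative
        pip_packages=["chromadb-client"],
        module="llama_stack.providers.remote.vector_io.chroma",  # assumed path
        config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",  # assumed path
        description="Chroma vector store adapter.",
    ),
    api_dependencies=[Api.inference],
    optional_api_dependencies=[Api.files],  # Files support is optional, not required
)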

View file

@ -10,6 +10,7 @@ import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
async def get_raw_document_text(document: Document) -> str:
if not document.mime_type.startswith("text/"):
# Handle deprecated text/yaml mime type with warning
if document.mime_type == "text/yaml":
warnings.warn(
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
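A hedged usage sketch of the deprecation path above; the Document constructor arguments are assumed from the surrounding code, and the call must run inside an event loop:

import warnings

async def demo() -> None:
    doc = Document(content="key: value", mime_type="text/yaml")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        text = await get_raw_document_text(doc)  # still returns the raw YAML text
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)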

View file

@ -128,6 +128,11 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:

View file

@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return None
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return [
Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model:
return model

View file

@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
# Allow any model to be registered as a shield
# The model will be validated during runtime when making inference calls
pass
model_id = shield.provider_resource_id
if not model_id:
raise ValueError("Llama Guard shield must have a model id")
async def run_shield(
self,

View file

@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
from llama_stack.log import get_logger
logger = get_logger(name="console_span_processor", category="telemetry")
class ConsoleSpanProcessor(SpanProcessor):
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
span_context += " [bold red][ERROR][/bold red]"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
span_context += f" [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
span_context += f" ({duration_ms:.2f}ms)"
logger.info(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
logger.info(f" [dim]{key}[/dim]: {str_value}")
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict | list):
if isinstance(message, dict) or isinstance(message, list):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
severity_color = {
"error": "red",
"warn": "yellow",
"info": "white",
"debug": "dim",
}.get(severity, "white")
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
logger.info(f"/r[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""

View file

@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
ChromaVectorIOAdapter,
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -6,12 +6,25 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
return {"db_path": db_path}
def sample_run_config(
cls, __distro_dir__: str, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any
) -> dict[str, Any]:
return {
"db_path": db_path,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_inline_registry.db",
),
}
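Roughly what the new sample_run_config produces, assuming SqliteKVStoreConfig.sample_run_config fills in a sqlite-backed registry under the distro directory (the kvstore shape below is an approximation, not taken from this diff):

cfg = ChromaVectorIOConfig.sample_run_config(__distro_dir__="~/.llama/distributions/demo")
# roughly:
# {
#     "db_path": "${env.CHROMADB_PATH}",
#     "kvstore": {"type": "sqlite", "db_path": "~/.llama/distributions/demo/chroma_inline_registry.db"},
# }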

View file

@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
self.kvstore = kvstore
self.bank_id = bank_id
# A list of chunk id's in the same order as they are in the index,
# must be updated when chunks are added or removed
self.chunk_id_lock = asyncio.Lock()
self.chunk_ids: list[Any] = []
@classmethod
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
try:
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e:
logger.debug(e, exc_info=True)
raise ValueError(
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.index.add(np.array(embeddings).astype(np.float32))
async with self.chunk_id_lock:
self.index.add(np.array(embeddings).astype(np.float32))
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
return
async with self.chunk_id_lock:
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
new_chunk_by_index = {}
for idx, chunk in self.chunk_by_index.items():
# Shift all chunks after the removed chunk to the left
if idx > index:
new_chunk_by_index[idx - 1] = chunk
else:
new_chunk_by_index[idx] = chunk
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
await self._save_index()
async def query_vector(
self,
embedding: NDArray,
@ -261,47 +289,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file data to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store data from kvstore."""
assert self.kvstore is not None
keys_to_delete = [
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
]
for key in keys_to_delete:
try:
await self.kvstore.delete(key)
except Exception as e:
logger.warning(f"Failed to delete key {key}: {e}")
continue
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)
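A hedged usage sketch of the new deletion path; the store and chunk ids are illustrative, and `adapter` is assumed to be an initialized FaissVectorIOAdapter:

await adapter.delete_chunks("my-store", ["chunk-1", "chunk-2"])
# each id is routed to FaissIndex.delete_chunk(), which removes the embedding by its
# position in chunk_ids, compacts chunk_by_index, and re-saves the serialized index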

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import re
import sqlite3
@ -426,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the SQLite vector store."""
def _delete_chunk():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@ -506,140 +534,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _create_or_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist OpenAI vector store files.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
connection.commit()
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)),
)
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_contents)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_create_or_store)
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(metadata,) = row
return metadata
finally:
cur.close()
connection.close()
stored_data = await asyncio.to_thread(_load)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(contents,) = row
return contents
finally:
cur.close()
connection.close()
stored_contents = await asyncio.to_thread(_load)
return json.loads(stored_contents) if stored_contents else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
(json.dumps(file_info), store_id, file_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
@ -655,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.together_openai_compat",
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.groq_openai_compat",
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
2. Configure your Llama Stack project to use pgvector (e.g. `remote::pgvector`).
3. Start storing and querying vectors.
## Installation
@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,

View file

@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)

View file

@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}
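The `:=` suffix added above follows the env-substitution convention used elsewhere in this diff (e.g. `${env.OLLAMA_URL:=http://localhost:11434}`); a sketch of the intended reading, stated as an assumption rather than a spec:

# "${env.ANTHROPIC_API_KEY}"                   -> required; no fallback when the variable is unset
# "${env.ANTHROPIC_API_KEY:=}"                 -> optional; expands to an empty string when unset
# "${env.OLLAMA_URL:=http://localhost:11434}"  -> optional; falls back to the given default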

View file

@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
"anthropic/claude-3-5-sonnet-latest",
"anthropic/claude-3-7-sonnet-latest",
"anthropic/claude-3-5-haiku-latest",
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
provider_model_id="voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
provider_model_id="voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
provider_model_id="voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),

View file

@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self._config = config
self._client = create_bedrock_client(config)
self._client = None
@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
if self._client is not None:
self._client.close()
async def completion(
self,

View file

@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
)
self.config = config
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),

View file

@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": api_key,

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter
adapter = CerebrasCompatInferenceAdapter(config)
return adapter

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..cerebras.models import MODEL_ENTRIES
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: CerebrasCompatConfig
def __init__(self, config: CerebrasCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="cerebras_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@json_schema_type
class CerebrasCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Cerebras API key",
)
openai_compat_api_base: str = Field(
default="https://api.cerebras.ai/v1",
description="The URL for the Cerebras API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.cerebras.ai/v1",
"api_key": api_key,
}

View file

@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FireworksImplConfig(BaseModel):
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
@ -23,7 +24,7 @@ class FireworksImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,

View file

@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter
adapter = FireworksCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class FireworksProviderDataValidator(BaseModel):
fireworks_api_key: str | None = Field(
default=None,
description="API key for Fireworks models",
)
@json_schema_type
class FireworksCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Fireworks API key",
)
openai_compat_api_base: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..fireworks.models import MODEL_ENTRIES
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: FireworksCompatConfig
def __init__(self, config: FireworksCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="fireworks_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)

View file

@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
"gemini/gemini-2.0-flash",
"gemini/gemini-2.5-flash",
"gemini/gemini-2.5-pro",
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.5-flash",
"gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
provider_model_id="text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),

View file

@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": api_key,

View file

@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,

View file

@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"groq/llama3-8b-8192",
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"groq/llama-3.1-8b-instant",
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama3-70b-8192",
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-3.3-70b-versatile",
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"groq/llama-3.2-3b-preview",
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-scout-17b-16e-instruct",
"meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
"meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter
adapter = GroqCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@json_schema_type
class GroqCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Groq API key",
)
openai_compat_api_base: str = Field(
default="https://api.groq.com/openai/v1",
description="The URL for the Groq API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..groq.models import MODEL_ENTRIES
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: GroqCompatConfig
def __init__(self, config: GroqCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -5,55 +5,53 @@
# the root directory of this source tree.
import logging
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Check if a specific model is available from Llama API.
Get the base URL for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The Llama API base URL
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
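The inheritance-order note in the docstring above comes down to Python's MRO; a minimal, self-contained sketch with illustrative classes (not taken from this change):

class RegistryHelper:
    def check_model_availability(self, model: str) -> bool:
        return False  # static registry: unknown models are rejected

class RemoteMixin:
    def check_model_availability(self, model: str) -> bool:
        return True  # pretend we asked the remote API

class Adapter(RemoteMixin, RegistryHelper):  # the mixin listed first wins under the MRO
    pass

assert Adapter().check_model_availability("some-model") is True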

View file

@ -7,9 +7,8 @@
import logging
import warnings
from collections.abc import AsyncIterator
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
from openai import APIConnectionError, BadRequestError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config
async def check_model_availability(self, model: str) -> bool:
def get_api_key(self) -> str:
"""
Check if a specific model is available.
Get the API key for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The NVIDIA API key
"""
try:
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
@property
def _client(self) -> AsyncOpenAI:
def get_base_url(self) -> str:
"""
Returns an OpenAI client for the configured NVIDIA API endpoint.
Get the base URL for OpenAI mixin.
:return: An OpenAI client
:return: The NVIDIA API base URL
"""
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def completion(
self,
@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.completions.create(**request)
response = await self.client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._client.embeddings.create(
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.chat.completions.create(**request)
response = await self.client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -96,14 +96,16 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
@property
def client(self) -> AsyncClient:
if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
return self._clients[loop]
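A hedged sketch of the failure mode this per-loop cache avoids: each asyncio.run() creates a fresh event loop, and an AsyncClient bound to an earlier loop cannot safely be reused, so the adapter hands out one client per running loop (assumes a reachable Ollama server):

import asyncio

adapter = OllamaInferenceAdapter(OllamaImplConfig())

async def ping() -> None:
    await adapter.client.list()  # lazily creates a client keyed to the current loop

asyncio.run(ping())  # first loop gets its own AsyncClient
asyncio.run(ping())  # a second loop gets a separate AsyncClient from the cache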
@property
def openai_client(self) -> AsyncOpenAI:
@ -119,59 +121,61 @@ class OllamaInferenceAdapter(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
while True:
try:
response = await self.client.list()
except Exception as e:
logger.warning(f"Failed to list models: {str(e)}")
await asyncio.sleep(self.config.refresh_models_interval)
response = await self.client.list()
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if "bert" in m.details.family:
continue
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
if model_type == ModelType.embedding:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=ModelType.llm,
)
await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)
)
return models
async def health(self) -> HealthResponse:
"""
@ -223,12 +227,7 @@ class OllamaInferenceAdapter(
return available_models
async def shutdown(self) -> None:
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls,
api_key: str = "${env.OPENAI_API_KEY:=}",
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"base_url": base_url,
}
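For reference, the rendered sample config now carries both knobs, with values taken directly from the defaults above:

OpenAIConfig.sample_run_config()
# -> {
#     "api_key": "${env.OPENAI_API_KEY:=}",
#     "base_url": "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
# }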

View file

@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
@ -43,8 +38,6 @@ class EmbeddingModelInfo:
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}

View file

@ -5,23 +5,9 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
# This OpenAI adapter implements Inference methods using two mixins -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
@ -39,15 +25,27 @@ logger = logging.getLogger(__name__)
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
# | openai_embeddings | AsyncOpenAI |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
@ -60,191 +58,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the OpenAI API base URL.
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
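The inheritance-order note in the adapter docstring above comes down to Python's left-to-right MRO; a stripped-down sketch with stand-in classes (not the real mixins) makes it concrete:

class RegistryFallback:
    async def check_model_availability(self, model: str) -> bool:
        return False  # static-registry behaviour, like ModelRegistryHelper

class LiteLLMLike(RegistryFallback):
    pass

class OpenAILike:
    async def check_model_availability(self, model: str) -> bool:
        return True  # live lookup against the provider's /models endpoint, like OpenAIMixin

class Adapter(OpenAILike, LiteLLMLike):
    pass

# Left-to-right MRO means OpenAILike's implementation wins:
# Adapter -> OpenAILike -> LiteLLMLike -> RegistryFallback -> object
print([c.__name__ for c in Adapter.__mro__])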

View file

@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,

View file

@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct",
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.3-70B-Instruct",
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
"Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
)

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter
adapter = SambaNovaCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="API key for SambaNova models",
)
@json_schema_type
class SambaNovaCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The SambaNova API key",
)
openai_compat_api_base: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..sambanova.models import MODEL_ENTRIES
class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaCompatConfig
def __init__(self, config: SambaNovaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL}",
url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {

View file

@ -305,6 +305,8 @@ class _HfAdapter(
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TogetherImplConfig(BaseModel):
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
@ -26,5 +27,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}",
"api_key": "${env.TOGETHER_API_KEY:=}",
}

View file

@ -69,15 +69,9 @@ MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
],
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
),
] + SAFETY_MODELS_ENTRIES

View file

@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter
adapter = TogetherCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class TogetherProviderDataValidator(BaseModel):
together_api_key: str | None = Field(
default=None,
description="API key for Together models",
)
@json_schema_type
class TogetherCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Together API key",
)
openai_compat_api_base: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.together.xyz/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..together.models import MODEL_ENTRIES
class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: TogetherCompatConfig
def __init__(self, config: TogetherCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="together_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify")
@classmethod

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@ -302,64 +300,32 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
pass
async def unregister_model(self, model_id: str) -> None:
pass
@ -374,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status.
"""
try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
@ -392,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None:
return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
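Because vLLM exposes an OpenAI-compatible /v1/models endpoint, the dynamic listing above can be reproduced standalone with the stock client (a sketch; the base_url and placeholder API key are assumptions):

from openai import AsyncOpenAI

async def list_vllm_model_ids(base_url: str, api_key: str = "fake") -> list[str]:
    # Same call pattern as list_models() above, minus the Model() wrapping.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    return [m.id async for m in client.models.list()]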

View file

@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,

View file

@ -12,6 +12,6 @@ from .config import ChromaVectorIOConfig
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -12,25 +12,19 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -42,6 +36,13 @@ log = logging.getLogger(__name__)
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:chroma:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
# this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result):
@ -51,16 +52,20 @@ async def maybe_await(result):
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: ChromaClientType, collection):
def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None):
self.client = client
self.collection = collection
self.kvstore = kvstore
async def initialize(self):
pass
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
@ -110,6 +115,9 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_hybrid(
self,
embedding: NDArray,
@ -122,24 +130,26 @@ class ChromaIndex(EmbeddingIndex):
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
self.cache = {}
self.kvstore: KVStore | None = None
self.vector_db_store = None
async def initialize(self) -> None:
if isinstance(self.config, RemoteChromaVectorIOConfig):
if not self.config.url:
raise ValueError("URL is a required parameter for the remote Chroma provider's config")
self.kvstore = await kvstore_impl(self.config.kvstore)
self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
@ -151,6 +161,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
pass
@ -170,6 +181,10 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
@ -180,6 +195,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
ttl_seconds: int | None = None,
) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
await index.insert_chunks(chunks)
@ -191,6 +208,9 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
@ -207,106 +227,5 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db_id] = index
return index
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
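The add_chunks id scheme above moved from position-based ids to ids derived from the chunk itself; a small sketch of the difference (the Chunk shape here is simplified to the two fields involved):

from dataclasses import dataclass, field

@dataclass
class FakeChunk:
    chunk_id: str
    metadata: dict = field(default_factory=dict)

chunks = [FakeChunk("c-abc", {"document_id": "doc1"}), FakeChunk("c-def")]

# old scheme: f"{c.metadata['document_id']}:chunk-{i}" -- positional, and it
# raises KeyError when a chunk has no document_id in its metadata.
# new scheme: stable per-chunk ids, tolerant of a missing document_id.
new_ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
print(new_ids)  # ['doc1:c-abc', ':c-def']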

View file

@ -6,12 +6,23 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
url: str | None
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
return {"url": url}
def sample_run_config(cls, __distro_dir__: str, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
return {
"url": url,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_remote_registry.db",
),
}

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import os
import re
@ -248,6 +247,16 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the Milvus collection."""
try:
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -371,185 +380,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to Milvus database."""
if store_id not in self.openai_vector_stores:
store_info = await self._load_openai_vector_stores(store_id)
if not store_info:
logger.error(f"OpenAI vector store {store_id} not found")
raise ValueError(f"No vector store found with id {store_id}")
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
file_schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=True,
description="Metadata for OpenAI vector store files",
)
file_schema.add_field(
field_name="store_file_id", datatype=DataType.VARCHAR, is_primary=True, max_length=512
)
file_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
await asyncio.to_thread(
self.client.create_collection,
collection_name="openai_vector_store_files",
schema=file_schema,
)
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
content_schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=True,
description="Contents for OpenAI vector store files",
)
content_schema.add_field(
field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=1024
)
content_schema.add_field(field_name="store_file_id", datatype=DataType.VARCHAR, max_length=1024)
content_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)
await asyncio.to_thread(
self.client.create_collection,
collection_name="openai_vector_store_files_contents",
schema=content_schema,
)
file_data = [
{
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files",
data=file_data,
)
# Save file contents
contents_data = [
{
"chunk_id": content.get("chunk_metadata").get("chunk_id"),
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"content": json.dumps(content),
}
for content in file_contents
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files_contents",
data=contents_data,
)
except Exception as e:
logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
return {}
query_filter = f"store_file_id == '{store_id}_{file_id}'"
results = await asyncio.to_thread(
self.client.query,
collection_name="openai_vector_store_files",
filter=query_filter,
output_fields=["file_info"],
)
if results:
try:
return json.loads(results[0]["file_info"])
except json.JSONDecodeError as e:
logger.error(f"Failed to decode file_info for store {store_id}, file {file_id}: {e}")
return {}
return {}
except Exception as e:
logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
return {}
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
return
file_data = [
{
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files",
data=file_data,
)
except Exception as e:
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
return []
query_filter = (
f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
)
results = await asyncio.to_thread(
self.client.query,
collection_name="openai_vector_store_files_contents",
filter=query_filter,
output_fields=["chunk_id", "store_id", "file_id", "content"],
)
contents = []
for result in results:
try:
content = json.loads(result["content"])
contents.append(content)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode content for store {store_id}, file {file_id}: {e}")
return contents
except Exception as e:
logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
return []
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
return
query_filter = f"store_file_id in ['{store_id}_{file_id}']"
await asyncio.to_thread(
self.client.delete,
collection_name="openai_vector_store_files",
filter=query_filter,
)
if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
await asyncio.to_thread(
self.client.delete,
collection_name="openai_vector_store_files_contents",
filter=query_filter,
)
except Exception as e:
logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
raise
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)
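The same deletion path as a standalone helper (a sketch assuming a pymilvus MilvusClient and the chunk_id field used by the schemas above):

import asyncio
from pymilvus import MilvusClient

async def delete_chunks_by_id(client: MilvusClient, collection_name: str, chunk_ids: list[str]) -> None:
    # Mirrors MilvusIndex.delete_chunk: filter on the chunk_id column, one id per call.
    for chunk_id in chunk_ids:
        await asyncio.to_thread(
            client.delete,
            collection_name=collection_name,
            filter=f'chunk_id == "{chunk_id}"',
        )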

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
values.append(
(
f"{chunk.metadata['document_id']}:chunk-{i}",
f"{chunk.chunk_id}",
Json(chunk.model_dump()),
embeddings[i].tolist(),
)
@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the PostgreSQL table."""
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -266,124 +271,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
# OpenAI Vector Stores File operations are not supported in PGVector
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
# Insert file metadata
files_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
files_values = [(store_id, file_id, Json(file_info))]
execute_values(cur, files_query, files_values, template="(%s, %s, %s)")
# Insert file contents
contents_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files_contents (store_id, file_id, contents)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET contents = EXCLUDED.contents
"""
)
contents_values = [(store_id, file_id, Json(file_contents))]
execute_values(cur, contents_query, contents_values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
raise
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else {}
except Exception as e:
log.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
return {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else []
except Exception as e:
log.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
return []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
values = [(store_id, file_id, Json(file_info))]
execute_values(cur, query, values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
except Exception as e:
log.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
raise
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in qdrant")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = (
await self.client.query_points(
@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -66,6 +66,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter(
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")

View file

@ -88,7 +88,7 @@ class SentenceTransformerEmbeddingMixin:
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
model=model,
usage=usage,
)

View file

@ -68,11 +68,23 @@ class LiteLLMOpenAIMixin(
def __init__(
self,
model_entries,
litellm_provider_name: str,
api_key_from_config: str | None,
provider_data_api_key_field: str,
openai_compat_api_base: str | None = None,
):
"""
Initialize the LiteLLMOpenAIMixin.
:param model_entries: The model entries to register.
:param api_key_from_config: The API key to use from the config.
:param provider_data_api_key_field: The field in the provider data that contains the API key.
:param litellm_provider_name: The name of the provider, used for model lookups.
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
"""
ModelRegistryHelper.__init__(self, model_entries)
self.litellm_provider_name = litellm_provider_name
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
self.api_base = openai_compat_api_base
@ -91,7 +103,11 @@ class LiteLLMOpenAIMixin(
def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
return (
f"{self.litellm_provider_name}/{model_id}"
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
else model_id
)
async def completion(
self,
@ -421,3 +437,17 @@ class LiteLLMOpenAIMixin(
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available via LiteLLM for the current
provider (self.litellm_provider_name).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if self.litellm_provider_name not in litellm.models_by_provider:
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
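What this check amounts to, as a standalone sketch (relies only on litellm's public models_by_provider mapping, the same one the method above reads):

import litellm

def is_known_to_litellm(provider: str, model: str) -> bool:
    # A static metadata lookup, not a live API call: it only reflects models
    # that the installed litellm version ships entries for.
    return model in litellm.models_by_provider.get(provider, [])

# e.g. is_known_to_litellm("openai", "gpt-4o")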

View file

@ -10,12 +10,22 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
# TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class.
@ -40,7 +50,8 @@ def build_hf_repo_model_entry(
additional_aliases: list[str] | None = None,
) -> ProviderModelEntry:
aliases = [
get_huggingface_repo(model_descriptor),
# NOTE: avoid HF aliases because they _cannot_ be unique across providers
# get_huggingface_repo(model_descriptor),
]
if additional_aliases:
aliases.extend(additional_aliases)
@ -62,7 +73,12 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: list[ProviderModelEntry]):
__provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
self.model_entries = model_entries
self.allowed_models = allowed_models
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for entry in model_entries:
@ -76,6 +92,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
async def list_models(self) -> list[Model] | None:
models = []
for entry in self.model_entries:
ids = [entry.provider_model_id] + entry.aliases
for id in ids:
if self.allowed_models and id not in self.allowed_models:
continue
models.append(
Model(
identifier=id,
provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)
)
return models
async def should_refresh_models(self) -> bool:
return False
def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)
@ -98,6 +135,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
logger.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False
async def register_model(self, model: Model) -> Model:
@ -148,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return model
async def unregister_model(self, model_id: str) -> None:
# TODO: should we block unregistering base supported provider model IDs?
if model_id not in self.alias_to_provider_id_map:
raise ValueError(f"Model id '{model_id}' is not registered.")
del self.alias_to_provider_id_map[model_id]
# model_id is the identifier, not the provider_resource_id
# unfortunately, this ID can be of the form provider_id/model_id which
# we never registered. TODO: fix this by significantly rewriting
# registration and registry helper
pass
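A compact sketch of the allowed_models filtering that list_models applies above (entry shape simplified to the fields the filter reads; names are illustrative):

def filter_model_ids(entries: list[dict], allowed_models: list[str] | None) -> list[str]:
    # Every provider model id plus its aliases is a candidate; allowed_models,
    # when set, acts as an allow-list over those candidate ids.
    ids: list[str] = []
    for entry in entries:
        for candidate in [entry["provider_model_id"], *entry.get("aliases", [])]:
            if allowed_models and candidate not in allowed_models:
                continue
            ids.append(candidate)
    return ids

entries = [{"provider_model_id": "Meta-Llama-3.1-8B-Instruct", "aliases": ["llama3.1-8b"]}]
print(filter_model_ids(entries, allowed_models=["llama3.1-8b"]))  # ['llama3.1-8b']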

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
import openai
from openai import NOT_GIVEN, AsyncOpenAI
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
@abstractmethod
def get_api_key(self) -> str:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""
Get the OpenAI-compatible API base URL.
This method must be implemented by child classes to provide the base URL
for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").
:return: The base URL as a string
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
Get an AsyncOpenAI client instance.
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
"""
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
)
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
This is a utility method that looks up the registered model and returns
the provider_resource_id that should be used for actual API calls.
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
# Direct model lookup - returns model or raises NotFoundError
await self.client.models.retrieve(model)
return True
except openai.NotFoundError:
# Model doesn't exist - this is expected for unavailable models
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
return False
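# Editor's illustrative sketch (not part of the diff): a minimal provider built on the
# mixin above. The class name, environment variable, and base URL are hypothetical
# examples; a real provider would normally read these from its config object.
import os  # used only by the sketch below


class ExampleOpenAIProvider(OpenAIMixin):
    def get_api_key(self) -> str:
        # assumption: the API key is supplied via an environment variable
        return os.environ["EXAMPLE_OPENAI_API_KEY"]

    def get_base_url(self) -> str:
        # assumption: the provider talks to the public OpenAI endpoint
        return "https://api.openai.com/v1"


# Once the Stack injects `model_store`, the mixin supplies openai_completion,
# openai_chat_completion, openai_embeddings, and check_model_availability with no
# further provider-specific code.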

View file

@ -66,7 +66,7 @@ class OpenAIVectorStoreMixin(ABC):
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
assert self.kvstore is not None
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
@ -74,7 +74,7 @@ class OpenAIVectorStoreMixin(ABC):
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
assert self.kvstore is not None
assert self.kvstore
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_data = await self.kvstore.values_in_range(start_key, end_key)
@ -87,7 +87,7 @@ class OpenAIVectorStoreMixin(ABC):
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
assert self.kvstore is not None
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
@ -95,37 +95,66 @@ class OpenAIVectorStoreMixin(ABC):
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
assert self.kvstore is not None
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
# remove from in-memory cache
self.openai_vector_stores.pop(store_id, None)
@abstractmethod
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to persistent storage."""
pass
assert self.kvstore
meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=meta_key, value=json.dumps(file_info))
contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
for idx, chunk in enumerate(file_contents):
await self.kvstore.set(key=f"{contents_prefix}{idx}", value=json.dumps(chunk))
@abstractmethod
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from persistent storage."""
pass
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
@abstractmethod
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from persistent storage."""
pass
assert self.kvstore
prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
end_key = f"{prefix}\xff"
raw_items = await self.kvstore.values_in_range(prefix, end_key)
return [json.loads(item) for item in raw_items]
@abstractmethod
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in persistent storage."""
pass
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
@abstractmethod
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from persistent storage."""
assert self.kvstore
meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(meta_key)
contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:"
end_key = f"{contents_prefix}\xff"
# load all stored chunk values (values_in_range is implemented by all backends)
raw_items = await self.kvstore.values_in_range(contents_prefix, end_key)
# delete each chunk by its index suffix
for idx in range(len(raw_items)):
await self.kvstore.delete(f"{contents_prefix}{idx}")
async def initialize_openai_vector_stores(self) -> None:
"""Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a vector store."""
pass
@abstractmethod
@ -138,10 +167,6 @@ class OpenAIVectorStoreMixin(ABC):
"""Unregister a vector database (provider-specific implementation)."""
pass
async def initialize_openai_vector_stores(self) -> None:
"""Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def insert_chunks(
self,
@ -161,7 +186,7 @@ class OpenAIVectorStoreMixin(ABC):
async def openai_create_vector_store(
self,
name: str,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@ -743,17 +768,15 @@ class OpenAIVectorStoreMixin(ABC):
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
store_info = self.openai_vector_stores[vector_store_id].copy()
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
# TODO: We need to actually delete the embeddings from the underlying vector store...
# Also uncomment the corresponding integration test marked as xfail
#
# test_openai_vector_store_delete_file_removes_from_vector_store in
# tests/integration/vector_io/test_openai_vector_stores.py
# Update in-memory cache
store_info["file_ids"].remove(file_id)
store_info["file_counts"][file.status] -= 1

View file

@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
raise NotImplementedError()
@abstractmethod
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
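# Editor's illustrative sketch (not part of the diff): a toy in-memory index showing how
# a concrete EmbeddingIndex subclass could satisfy the new `delete_chunk` hook. Only the
# methods relevant here are shown; a real implementation must also provide the remaining
# abstract methods (e.g. query_vector).
class _ToyEmbeddingIndex(EmbeddingIndex):
    def __init__(self) -> None:
        self._chunks: dict[str, Chunk] = {}
        self._vectors: dict[str, NDArray] = {}

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        # store each chunk and its embedding row, keyed by chunk_id
        for chunk, vector in zip(chunks, embeddings):
            self._chunks[str(chunk.chunk_id)] = chunk
            self._vectors[str(chunk.chunk_id)] = vector

    async def delete_chunk(self, chunk_id: str):
        # deleting an unknown id is treated as a no-op (an assumption for this sketch)
        self._chunks.pop(chunk_id, None)
        self._vectors.pop(chunk_id, None)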

View file

@ -83,6 +83,7 @@ class SQLiteTraceStore(TraceStore):
)
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
FROM filtered_traces
WHERE root_span_id IS NOT NULL
LIMIT {limit} OFFSET {offset}
"""
@ -166,7 +167,11 @@ class SQLiteTraceStore(TraceStore):
return spans_by_id
async def get_trace(self, trace_id: str) -> Trace:
query = "SELECT * FROM traces WHERE trace_id = ?"
query = """
SELECT *
FROM traces t
WHERE t.trace_id = ?
"""
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (trace_id,)) as cursor:

View file

@ -4,13 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import httpx
from mcp import ClientSession
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a TTL'd dict to cache the happy-path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == MCPProtol.SSE:
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies):
try:
client = streamablehttp_client
if strategy == MCPProtol.SSE:
client = sse_client
async with client(endpoint, headers=headers) as client_streams:
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
await session.initialize()
protocol_cache[endpoint] = strategy
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise
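# Behaviour sketch (editor's note, not from the diff): the first successful connection to
# an endpoint records its protocol in `protocol_cache` for an hour, so later calls try
# that protocol first; a cached SSE entry simply flips the order of the two strategies
# above, and either protocol is still retried if the cached one fails to initialize.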
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from threading import RLock
from typing import Any
class TTLDict(dict):
"""
A dictionary that applies a time-to-live (TTL) to each item
"""
def __init__(self, ttl_seconds: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ttl_seconds = ttl_seconds
self._expires: dict[Any, float] = {}  # maps each key to its monotonic expiry time
self._lock = RLock()
if args or kwargs:
for k, v in self.items():
self.__setitem__(k, v)
def __delitem__(self, key):
with self._lock:
del self._expires[key]
super().__delitem__(key)
def __setitem__(self, key, value):
with self._lock:
self._expires[key] = time.monotonic() + self.ttl_seconds
super().__setitem__(key, value)
def _is_expired(self, key):
if key not in self._expires:
return False
return time.monotonic() > self._expires[key]
def __getitem__(self, key):
with self._lock:
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
raise KeyError(f"{key} has expired and was removed")
return super().__getitem__(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
_ = self[key]
return True
except KeyError:
return False
def __repr__(self):
with self._lock:
for key in list(self.keys()):  # iterate over a copy, since expired keys are deleted below
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"