Merge branch 'main' into dead_code_removal

Omar Abdelwahab 2025-10-06 13:21:36 -07:00 committed by GitHub
commit 9886520b40
927 changed files with 171924 additions and 102933 deletions


@@ -6,10 +6,12 @@
 import os

-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


-class BedrockBaseConfig(BaseModel):
+class BedrockBaseConfig(RemoteInferenceProviderConfig):
     aws_access_key_id: str | None = Field(
         default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
         description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",


@@ -11,12 +11,8 @@ import litellm
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
     InferenceProvider,
     JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
@@ -24,12 +20,7 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
-    ResponseFormat,
-    SamplingParams,
     ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
@@ -37,8 +28,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
     prepare_openai_completion_params,
@@ -105,57 +94,6 @@ class LiteLLMOpenAIMixin(
             else model_id
         )

-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        params = await self._get_params(request)
-        params["model"] = self.get_litellm_model_name(params["model"])
-
-        logger.debug(f"params to litellm (openai compat): {params}")
-        # see https://docs.litellm.ai/docs/completion/stream#async-completion
-        response = await litellm.acompletion(**params)
-        if stream:
-            return self._stream_chat_completion(response)
-        else:
-            return convert_openai_chat_completion_choice(response.choices[0])
-
-    async def _stream_chat_completion(
-        self, response: litellm.ModelResponse
-    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
-        async def _stream_generator():
-            async for chunk in response:
-                yield chunk
-
-        async for chunk in convert_openai_chat_completion_stream(
-            _stream_generator(), enable_incremental_tool_calls=True
-        ):
-            yield chunk
-
     def _add_additional_properties_recursive(self, schema):
         """
         Recursively add additionalProperties: False to all object schemas
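
The deleted chat_completion wrapped litellm's async streaming API and converted its chunks into Llama Stack types; with the conversion layer gone, callers go through the OpenAI-compatible openai_chat_completion path instead. For reference, a minimal sketch of the litellm call the removed code was built on (model name and prompt are illustrative):

    import asyncio

    import litellm

    async def main() -> None:
        # With stream=True, acompletion returns an async iterator of chunks
        # (https://docs.litellm.ai/docs/completion/stream#async-completion).
        response = await litellm.acompletion(
            model="openai/gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hello"}],
            stream=True,
        )
        async for chunk in response:
            delta = chunk.choices[0].delta.content
            if delta:
                print(delta, end="")

    asyncio.run(main())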


@@ -7,10 +7,11 @@
 import base64
 import uuid
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import Any

 from openai import NOT_GIVEN, AsyncOpenAI
+from pydantic import BaseModel, ConfigDict

 from llama_stack.apis.inference import (
     Model,
@@ -26,14 +27,14 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import ModelType
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content

 logger = get_logger(name=__name__, category="providers::utils")


-class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
+class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     """
     Mixin class that provides OpenAI-specific functionality for inference providers.

     This class handles direct OpenAI API calls using the AsyncOpenAI client.
@@ -42,12 +43,25 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
     - get_api_key(): Method to retrieve the API key
     - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

     The behavior of this class can be customized by child classes in the following ways:
     - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
     - download_images: If True, downloads images and converts to base64 for providers that require it
     - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
+    - provider_data_api_key_field: Optional field name in provider data to look for API key
+    - list_provider_model_ids: Method to list available models from the provider
+    - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
+
+    Expected Dependencies:
+    - self.model_store: Injected by the Llama Stack distribution system at runtime.
+      This provides model registry functionality for looking up registered models.
+      The model_store is set in routing_tables/common.py during provider initialization.
     """

+    # Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
+    model_config = ConfigDict(extra="allow")
+
+    config: RemoteInferenceProviderConfig
+
     # Allow subclasses to control whether to overwrite the 'id' field in OpenAI responses
     # is overwritten with a client-side generated id.
     #
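
To make the subclass contract concrete, a hypothetical adapter built on this mixin might look like the sketch below. The class name, base URL, and method signatures are assumptions for illustration, not code from this commit:

    from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
    from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin  # assumed import path

    class AcmeInferenceAdapter(OpenAIMixin):
        """Hypothetical provider for an OpenAI-compatible endpoint."""

        config: RemoteInferenceProviderConfig

        def get_api_key(self) -> str:
            # A real adapter would read this from self.config or provider data.
            return "sk-example"

        def get_base_url(self) -> str:
            return "https://api.acme.example/v1"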
@@ -108,6 +122,38 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
         """
         return {}

+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        List available models from the provider.
+
+        Child classes can override this method to provide a custom implementation
+        for listing models. The default implementation uses the AsyncOpenAI client
+        to list models from the OpenAI-compatible endpoint.
+
+        :return: An iterable of model IDs or None if not implemented
+        """
+        return [m.id async for m in self.client.models.list()]
+
+    async def initialize(self) -> None:
+        """
+        Initialize the OpenAI mixin.
+
+        This method provides a default implementation that does nothing.
+        Subclasses can override this method to perform initialization tasks
+        such as setting up clients, validating configurations, etc.
+        """
+        pass
+
+    async def shutdown(self) -> None:
+        """
+        Shutdown the OpenAI mixin.
+
+        This method provides a default implementation that does nothing.
+        Subclasses can override this method to perform cleanup tasks
+        such as closing connections, releasing resources, etc.
+        """
+        pass
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -356,6 +402,24 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
             usage=usage,
         )

+    ###
+    # ModelsProtocolPrivate implementation - provide model management functionality
+    #
+    #  async def register_model(self, model: Model) -> Model: ...
+    #  async def unregister_model(self, model_id: str) -> None: ...
+    #
+    #  async def list_models(self) -> list[Model] | None: ...
+    #  async def should_refresh_models(self) -> bool: ...
+    ###
+
+    async def register_model(self, model: Model) -> Model:
+        if not await self.check_model_availability(model.provider_model_id):
+            raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")  # type: ignore[attr-defined]
+        return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        return None
+
     async def list_models(self) -> list[Model] | None:
         """
         List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.
@@ -366,28 +430,42 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
         """
         self._model_cache = {}

-        async for m in self.client.models.list():
-            if self.allowed_models and m.id not in self.allowed_models:
-                logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
+        try:
+            iterable = await self.list_provider_model_ids()
+        except Exception as e:
+            logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
+            raise
+        if not hasattr(iterable, "__iter__"):
+            raise TypeError(
+                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
+                f"strings, but returned {type(iterable).__name__}"
+            )
+        provider_models_ids = list(iterable)
+        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
+
+        for provider_model_id in provider_models_ids:
+            if not isinstance(provider_model_id, str):
+                raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
+            if self.allowed_models and provider_model_id not in self.allowed_models:
+                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
-            if metadata := self.embedding_model_metadata.get(m.id):
-                # This is an embedding model - augment with metadata
+            if metadata := self.embedding_model_metadata.get(provider_model_id):
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.embedding,
                     metadata=metadata,
                 )
             else:
-                # This is an LLM
                 model = Model(
                     provider_id=self.__provider_id__,  # type: ignore[attr-defined]
-                    provider_resource_id=m.id,
-                    identifier=m.id,
+                    provider_resource_id=provider_model_id,
+                    identifier=provider_model_id,
                     model_type=ModelType.llm,
                 )
-            self._model_cache[m.id] = model
+            self._model_cache[provider_model_id] = model

         return list(self._model_cache.values())
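
The type checks above make careless overrides fail fast. For example, an override that forgets to return its result hands back None, and list_models() raises the TypeError immediately instead of misbehaving later (sketch, continuing the hypothetical adapter):

    class BrokenAdapter(AcmeInferenceAdapter):
        async def list_provider_model_ids(self):
            ids = ["m1", "m2"]  # oops: no return statement, so the hook yields None

    # await BrokenAdapter(...).list_models() raises:
    # TypeError: Failed to list models: BrokenAdapter.list_provider_model_ids() must
    # return an iterable of strings, but returned NoneType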
@@ -400,5 +478,33 @@ class OpenAIMixin(ModelRegistryHelper, NeedsRequestProviderData, ABC):
         """
         if not self._model_cache:
             await self.list_models()

         return model in self._model_cache

+    async def should_refresh_models(self) -> bool:
+        return False
+
+    #
+    # The model_dump implementations are to avoid serializing the extra fields,
+    # e.g. model_store, which are not pydantic.
+    #
+    def _filter_fields(self, **kwargs):
+        """Helper to exclude extra fields from serialization."""
+        # Exclude any extra fields stored in __pydantic_extra__
+        if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
+            exclude = kwargs.get("exclude", set())
+            if not isinstance(exclude, set):
+                exclude = set(exclude) if exclude else set()
+            exclude.update(self.__pydantic_extra__.keys())
+            kwargs["exclude"] = exclude
+        return kwargs
+
+    def model_dump(self, **kwargs):
+        """Override to exclude extra fields from serialization."""
+        kwargs = self._filter_fields(**kwargs)
+        return super().model_dump(**kwargs)
+
+    def model_dump_json(self, **kwargs):
+        """Override to exclude extra fields from JSON serialization."""
+        kwargs = self._filter_fields(**kwargs)
+        return super().model_dump_json(**kwargs)
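
The problem these overrides solve is easy to reproduce: with extra="allow", runtime-injected attributes land in __pydantic_extra__ and would otherwise leak into dumps. A minimal sketch:

    from pydantic import BaseModel, ConfigDict

    class Mixin(BaseModel):
        model_config = ConfigDict(extra="allow")
        name: str = "provider"

    m = Mixin()
    m.model_store = object()  # injected extra, not a declared pydantic field
    print(m.model_dump())  # without the overrides this would include model_store
    # m.model_dump_json() would fail outright: object() is not JSON-serializable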


@@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
 # Constants for reranker types
 RERANKER_TYPE_RRF = "rrf"
 RERANKER_TYPE_WEIGHTED = "weighted"
+RERANKER_TYPE_NORMALIZED = "normalized"


 def parse_pdf(data: bytes) -> str:
@@ -325,6 +326,8 @@ class VectorDBWithIndex:
             weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
             reranker_type = RERANKER_TYPE_WEIGHTED
             reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
+        elif strategy == "normalized":
+            reranker_type = RERANKER_TYPE_NORMALIZED
         else:
             reranker_type = RERANKER_TYPE_RRF
             k_value = ranker.get("params", {}).get("k", 60.0)
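
Based on the ranker.get(...) accesses above, the query-time ranker spec is a plain dict, and the new strategy slots in alongside the existing two (parameter values shown are the defaults from this hunk):

    ranker = {"strategy": "normalized", "params": {}}                     # RERANKER_TYPE_NORMALIZED
    ranker = {"strategy": "weighted", "params": {"weights": [0.5, 0.5]}}  # alpha = weights[0]
    ranker = {"strategy": "rrf", "params": {"k": 60.0}}                   # fallback/default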


@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObject,
     OpenAIResponseObjectWithInput,
 )
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
 from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.log import get_logger
@@ -28,6 +29,19 @@ from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType

 logger = get_logger(name=__name__, category="openai_responses")


+class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
+    """Internal class for storing responses with chat completion messages.
+
+    This extends the public OpenAIResponseObjectWithInput with a messages field
+    for internal storage. The messages field is not exposed in the public API.
+
+    The messages field is optional for backward compatibility with responses
+    stored before this feature was added.
+    """
+
+    messages: list[OpenAIMessageParam] | None = None
+
+
 class ResponsesStore:
     def __init__(
         self,
@@ -54,7 +68,9 @@ class ResponsesStore:
         self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

         # Async write queue and worker control
-        self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
+        self._queue: (
+            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
+        ) = None
         self._worker_tasks: list[asyncio.Task[Any]] = []
         self._max_write_queue_size: int = config.max_write_queue_size
         self._num_writers: int = max(1, config.num_writers)
@@ -100,18 +116,21 @@ class ResponsesStore:
             await self._queue.join()

     async def store_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.enable_write_queue:
             if self._queue is None:
                 raise ValueError("Responses store is not initialized")
             try:
-                self._queue.put_nowait((response_object, input))
+                self._queue.put_nowait((response_object, input, messages))
             except asyncio.QueueFull:
                 logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
-                await self._queue.put((response_object, input))
+                await self._queue.put((response_object, input, messages))
         else:
-            await self._write_response_object(response_object, input)
+            await self._write_response_object(response_object, input, messages)

     async def _worker_loop(self) -> None:
         assert self._queue is not None
@@ -120,22 +139,26 @@ class ResponsesStore:
                 item = await self._queue.get()
             except asyncio.CancelledError:
                 break
-            response_object, input = item
+            response_object, input, messages = item
             try:
-                await self._write_response_object(response_object, input)
+                await self._write_response_object(response_object, input, messages)
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Error writing response object: {e}")
             finally:
                 self._queue.task_done()

     async def _write_response_object(
-        self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
+        self,
+        response_object: OpenAIResponseObject,
+        input: list[OpenAIResponseInput],
+        messages: list[OpenAIMessageParam],
     ) -> None:
         if self.sql_store is None:
             raise ValueError("Responses store is not initialized")

         data = response_object.model_dump()
         data["input"] = [input_item.model_dump() for input_item in input]
+        data["messages"] = [msg.model_dump() for msg in messages]

         await self.sql_store.insert(
             "openai_responses",
@@ -188,7 +211,7 @@ class ResponsesStore:
             last_id=data[-1].id if data else "",
         )

-    async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
+    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
         """
         Get a response object with automatic access control checking.
         """
@@ -205,7 +228,7 @@ class ResponsesStore:
             # This provides security by not revealing whether the record exists
             raise ValueError(f"Response with id {response_id} not found") from None

-        return OpenAIResponseObjectWithInput(**row["response_object"])
+        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])

     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
         if not self.sql_store:
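
The backward-compatibility note in the docstring falls out of pydantic defaulting: rows written before this change have no "messages" key, so the optional field fills in None. A standalone sketch of the same mechanics:

    from pydantic import BaseModel

    class StoredResponse(BaseModel):
        input: list[str]
        messages: list[str] | None = None  # optional, like the new internal field

    old_row = {"input": ["hello"]}  # persisted before this change
    print(StoredResponse(**old_row).messages)  # -> None, no validation error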
@@ -241,8 +264,8 @@ class ResponsesStore:
         if before and after:
             raise ValueError("Cannot specify both 'before' and 'after' parameters")

-        response_with_input = await self.get_response_object(response_id)
-        items = response_with_input.input
+        response_with_input_and_messages = await self.get_response_object(response_id)
+        items = response_with_input_and_messages.input

         if order == Order.desc:
             items = list(reversed(items))


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Literal, Protocol
@@ -41,9 +41,9 @@ class SqlStore(Protocol):
         """
         pass

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
         """
-        Insert a row into a table.
+        Insert a row or batch of rows into a table.
         """
         pass
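
Call sites can now pass either shape. A sketch against the widened protocol (the table name and rows are illustrative; store is any SqlStore implementation):

    from collections.abc import Mapping, Sequence
    from typing import Any

    async def seed(store, table: str) -> None:
        # Single row, exactly as before.
        await store.insert(table, {"key": "a", "value": "1"})
        # New: a batch of rows in one call.
        rows: Sequence[Mapping[str, Any]] = [
            {"key": "b", "value": "2"},
            {"key": "c", "value": "3"},
        ]
        await store.insert(table, rows)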


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, Literal

 from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
@@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
 ]


+def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
+    """Add access control attributes to a data item."""
+    enhanced = dict(item)
+    if current_user:
+        enhanced["owner_principal"] = current_user.principal
+        enhanced["access_attributes"] = current_user.attributes
+    else:
+        enhanced["owner_principal"] = None
+        enhanced["access_attributes"] = None
+    return enhanced
+
+
 class SqlRecord(ProtectedResource):
     def __init__(self, record_id: str, table_name: str, owner: User):
         self.type = f"sql_record::{table_name}"
@@ -102,18 +114,14 @@ class AuthorizedSqlStore:
         await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
         await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
-        """Insert a row with automatic access control attribute capture."""
-        enhanced_data = dict(data)
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
+        """Insert a row or batch of rows with automatic access control attribute capture."""
         current_user = get_authenticated_user()
-        if current_user:
-            enhanced_data["owner_principal"] = current_user.principal
-            enhanced_data["access_attributes"] = current_user.attributes
+
+        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
+        if isinstance(data, Mapping):
+            enhanced_data = _enhance_item_with_access_control(data, current_user)
         else:
-            enhanced_data["owner_principal"] = None
-            enhanced_data["access_attributes"] = None
+            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]

         await self.sql_store.insert(table, enhanced_data)

     async def fetch_all(
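
Both branches funnel through the helper added earlier, so every stamped row carries the same two access-control columns (illustrative values):

    # With an authenticated user (principal "alice", attributes {"team": "ml"}):
    {"key": "a", "owner_principal": "alice", "access_attributes": {"team": "ml"}}
    # With no authenticated user:
    {"key": "a", "owner_principal": None, "access_attributes": None}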


@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, Literal

 from sqlalchemy import (
@@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         async with engine.begin() as conn:
             await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
         async with self.async_session() as session:
             await session.execute(self.metadata.tables[table].insert(), data)
             await session.commit()
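
No batching branch was needed here because SQLAlchemy dispatches on the parameter shape itself: a single dict binds one row, while a list of dicts triggers an executemany. A standalone sketch (assumes the aiosqlite driver is installed; the schema is illustrative):

    import asyncio

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        metadata = MetaData()
        kv = Table("kv", metadata, Column("key", String), Column("value", String))
        async with engine.begin() as conn:
            await conn.run_sync(metadata.create_all)
        async with async_sessionmaker(engine)() as session:
            await session.execute(kv.insert(), {"key": "a", "value": "1"})  # single row
            await session.execute(
                kv.insert(),
                [{"key": "b", "value": "2"}, {"key": "c", "value": "3"}],  # executemany
            )
            await session.commit()

    asyncio.run(main())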