diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 8c69b1683..17f4c6268 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -6,6 +6,8 @@ import asyncio from typing import Any +from sqlalchemy.exc import IntegrityError + from llama_stack.apis.inference import ( ListOpenAIChatCompletionResponse, OpenAIChatCompletion, @@ -129,16 +131,44 @@ class InferenceStore: raise ValueError("Inference store is not initialized") data = chat_completion.model_dump() + record_data = { + "id": data["id"], + "created": data["created"], + "model": data["model"], + "choices": data["choices"], + "input_messages": [message.model_dump() for message in input_messages], + } - await self.sql_store.insert( - table="chat_completions", - data={ - "id": data["id"], - "created": data["created"], - "model": data["model"], - "choices": data["choices"], - "input_messages": [message.model_dump() for message in input_messages], - }, + try: + await self.sql_store.insert( + table="chat_completions", + data=record_data, + ) + except IntegrityError as e: + # Duplicate chat completion IDs can be generated during tests especially if they are replaying + # recorded responses across different tests. No need to warn or error under those circumstances. + # In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem. + + # Check if it's a unique constraint violation + error_message = str(e.orig) if e.orig else str(e) + if self._is_unique_constraint_error(error_message): + # Update the existing record instead + await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]}) + else: + # Re-raise if it's not a unique constraint error + raise + + def _is_unique_constraint_error(self, error_message: str) -> bool: + """Check if the error is specifically a unique constraint violation.""" + error_lower = error_message.lower() + return any( + indicator in error_lower + for indicator in [ + "unique constraint failed", # SQLite + "duplicate key", # PostgreSQL + "unique violation", # PostgreSQL alternative + "duplicate entry", # MySQL + ] ) async def list_chat_completions( diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 867ba2f55..acb688f96 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -172,6 +172,20 @@ class AuthorizedSqlStore: return results.data[0] if results.data else None + async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None: + """Update rows with automatic access control attribute capture.""" + enhanced_data = dict(data) + + current_user = get_authenticated_user() + if current_user: + enhanced_data["owner_principal"] = current_user.principal + enhanced_data["access_attributes"] = current_user.attributes + else: + enhanced_data["owner_principal"] = None + enhanced_data["access_attributes"] = None + + await self.sql_store.update(table, enhanced_data, where) + async def delete(self, table: str, where: Mapping[str, Any]) -> None: """Delete rows with automatic access control filtering.""" await self.sql_store.delete(table, where)