From 0c7f49490cdb6ff757659469d1401b515ac4402c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 10 Sep 2025 14:34:18 -0700
Subject: [PATCH] fix(inference_store): on duplicate chat completion IDs, replace (#3408)

# What does this PR do?

Duplicate chat completion IDs can be generated during tests especially if they
are replaying recorded responses across different tests. No need to warn or
error under those circumstances.

In the wild, this is not likely to happen at all (no evidence) so we aren't
really hiding any problem.
---
Reviewer notes (text between the three-dash line and the diffstat is not part
of the commit message):

  * InferenceStore.store_chat_completion: the row is now built once as
    `record_data` and passed to `sql_store.insert`. When the insert raises
    `sqlalchemy.exc.IntegrityError` and the error text (`e.orig` when set,
    otherwise the exception itself) is recognised by
    `_is_unique_constraint_error`, the existing row is overwritten through
    `sql_store.update(..., where={"id": data["id"]})`. Any other
    IntegrityError is re-raised unchanged.

  * InferenceStore._is_unique_constraint_error: lower-cases the message and
    returns True if it contains any of four fragments, labelled in the code
    as SQLite, PostgreSQL (two variants) and MySQL.
    NOTE(review): detection depends on driver message wording rather than an
    error code -- confirm the fragments against the drivers actually in use.

  * AuthorizedSqlStore.update: copies `data`, sets `owner_principal` and
    `access_attributes` from `get_authenticated_user()` (both None when no
    user is returned), then forwards `table`, the enhanced data and the
    caller's `where` to the wrapped `sql_store.update`.
    NOTE(review): no access-control condition is added to `where` in this
    hunk, so a duplicate ID appears to let the current caller overwrite and
    re-own a row stored under a different principal -- confirm this is
    acceptable outside of test replay.

 .../utils/inference/inference_store.py        | 48 +++++++++++++++----
 .../utils/sqlstore/authorized_sqlstore.py     | 14 ++++++
 2 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py
index 8c69b1683..17f4c6268 100644
--- a/llama_stack/providers/utils/inference/inference_store.py
+++ b/llama_stack/providers/utils/inference/inference_store.py
@@ -6,6 +6,8 @@
 import asyncio
 from typing import Any
 
+from sqlalchemy.exc import IntegrityError
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
@@ -129,16 +131,44 @@ class InferenceStore:
             raise ValueError("Inference store is not initialized")
 
         data = chat_completion.model_dump()
+        record_data = {
+            "id": data["id"],
+            "created": data["created"],
+            "model": data["model"],
+            "choices": data["choices"],
+            "input_messages": [message.model_dump() for message in input_messages],
+        }
 
-        await self.sql_store.insert(
-            table="chat_completions",
-            data={
-                "id": data["id"],
-                "created": data["created"],
-                "model": data["model"],
-                "choices": data["choices"],
-                "input_messages": [message.model_dump() for message in input_messages],
-            },
+        try:
+            await self.sql_store.insert(
+                table="chat_completions",
+                data=record_data,
+            )
+        except IntegrityError as e:
+            # Duplicate chat completion IDs can be generated during tests especially if they are replaying
+            # recorded responses across different tests. No need to warn or error under those circumstances.
+            # In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.
+
+            # Check if it's a unique constraint violation
+            error_message = str(e.orig) if e.orig else str(e)
+            if self._is_unique_constraint_error(error_message):
+                # Update the existing record instead
+                await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
+            else:
+                # Re-raise if it's not a unique constraint error
+                raise
+
+    def _is_unique_constraint_error(self, error_message: str) -> bool:
+        """Check if the error is specifically a unique constraint violation."""
+        error_lower = error_message.lower()
+        return any(
+            indicator in error_lower
+            for indicator in [
+                "unique constraint failed",  # SQLite
+                "duplicate key",  # PostgreSQL
+                "unique violation",  # PostgreSQL alternative
+                "duplicate entry",  # MySQL
+            ]
         )
 
     async def list_chat_completions(
diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
index 867ba2f55..acb688f96 100644
--- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
+++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
@@ -172,6 +172,20 @@ class AuthorizedSqlStore:
 
         return results.data[0] if results.data else None
 
+    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
+        """Update rows with automatic access control attribute capture."""
+        enhanced_data = dict(data)
+
+        current_user = get_authenticated_user()
+        if current_user:
+            enhanced_data["owner_principal"] = current_user.principal
+            enhanced_data["access_attributes"] = current_user.attributes
+        else:
+            enhanced_data["owner_principal"] = None
+            enhanced_data["access_attributes"] = None
+
+        await self.sql_store.update(table, enhanced_data, where)
+
     async def delete(self, table: str, where: Mapping[str, Any]) -> None:
         """Delete rows with automatic access control filtering."""
         await self.sql_store.delete(table, where)