fix(inference_store): on duplicate chat completion IDs, replace

This commit is contained in:
Ashwin Bharambe 2025-09-10 12:25:04 -07:00
parent c04f1c1e8c
commit 7eedb88edc
2 changed files with 52 additions and 9 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
from typing import Any from typing import Any
from sqlalchemy.exc import IntegrityError
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse, ListOpenAIChatCompletionResponse,
@ -129,16 +130,44 @@ class InferenceStore:
raise ValueError("Inference store is not initialized") raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump() data = chat_completion.model_dump()
record_data = {
await self.sql_store.insert(
table="chat_completions",
data={
"id": data["id"], "id": data["id"],
"created": data["created"], "created": data["created"],
"model": data["model"], "model": data["model"],
"choices": data["choices"], "choices": data["choices"],
"input_messages": [message.model_dump() for message in input_messages], "input_messages": [message.model_dump() for message in input_messages],
}, }
try:
await self.sql_store.insert(
table="chat_completions",
data=record_data,
)
except IntegrityError as e:
# Check if it's a unique constraint violation
error_message = str(e.orig) if e.orig else str(e)
if self._is_unique_constraint_error(error_message):
# Update the existing record instead
await self.sql_store.update(
table="chat_completions",
data=record_data,
where={"id": data["id"]}
)
else:
# Re-raise if it's not a unique constraint error
raise
def _is_unique_constraint_error(self, error_message: str) -> bool:
"""Check if the error is specifically a unique constraint violation."""
error_lower = error_message.lower()
return any(
indicator in error_lower
for indicator in [
"unique constraint failed", # SQLite
"duplicate key", # PostgreSQL
"unique violation", # PostgreSQL alternative
"duplicate entry", # MySQL
]
) )
async def list_chat_completions( async def list_chat_completions(

View file

@ -172,6 +172,20 @@ class AuthorizedSqlStore:
return results.data[0] if results.data else None return results.data[0] if results.data else None
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
"""Update rows with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.update(table, enhanced_data, where)
async def delete(self, table: str, where: Mapping[str, Any]) -> None: async def delete(self, table: str, where: Mapping[str, Any]) -> None:
"""Delete rows with automatic access control filtering.""" """Delete rows with automatic access control filtering."""
await self.sql_store.delete(table, where) await self.sql_store.delete(table, where)