mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
# What does this PR do?

The InferenceStore class was ignoring the table_name field from InferenceStoreReference and always using the hardcoded value "chat_completions". This meant that any custom table_name configured in the run config (e.g., "inference_store" in run-with-postgres-store.yaml) was silently ignored.

This change updates all SQL operations in InferenceStore to use self.reference.table_name instead of the hardcoded string, ensuring the configured table name is properly respected. A new test has been added to verify that custom table names work correctly for storing, retrieving, and listing chat completions.

## Test Plan

CI

<hr>

This is an automatic backport of pull request #4371 done by [Mergify](https://mergify.com).

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
254adf2e10
commit
a6a600f845
2 changed files with 34 additions and 5 deletions
|
|
@ -56,7 +56,7 @@ class InferenceStore:
|
||||||
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
||||||
|
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"chat_completions",
|
self.reference.table_name,
|
||||||
{
|
{
|
||||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||||
"created": ColumnType.INTEGER,
|
"created": ColumnType.INTEGER,
|
||||||
|
|
@ -161,7 +161,7 @@ class InferenceStore:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.sql_store.insert(
|
await self.sql_store.insert(
|
||||||
table="chat_completions",
|
table=self.reference.table_name,
|
||||||
data=record_data,
|
data=record_data,
|
||||||
)
|
)
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
|
@ -173,7 +173,7 @@ class InferenceStore:
|
||||||
error_message = str(e.orig) if e.orig else str(e)
|
error_message = str(e.orig) if e.orig else str(e)
|
||||||
if self._is_unique_constraint_error(error_message):
|
if self._is_unique_constraint_error(error_message):
|
||||||
# Update the existing record instead
|
# Update the existing record instead
|
||||||
await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
|
await self.sql_store.update(table=self.reference.table_name, data=record_data, where={"id": data["id"]})
|
||||||
else:
|
else:
|
||||||
# Re-raise if it's not a unique constraint error
|
# Re-raise if it's not a unique constraint error
|
||||||
raise
|
raise
|
||||||
|
|
@ -217,7 +217,7 @@ class InferenceStore:
|
||||||
where_conditions["model"] = model
|
where_conditions["model"] = model
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="chat_completions",
|
table=self.reference.table_name,
|
||||||
where=where_conditions if where_conditions else None,
|
where=where_conditions if where_conditions else None,
|
||||||
order_by=[("created", order.value)],
|
order_by=[("created", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
@ -246,7 +246,7 @@ class InferenceStore:
|
||||||
raise ValueError("Inference store is not initialized")
|
raise ValueError("Inference store is not initialized")
|
||||||
|
|
||||||
row = await self.sql_store.fetch_one(
|
row = await self.sql_store.fetch_one(
|
||||||
table="chat_completions",
|
table=self.reference.table_name,
|
||||||
where={"id": completion_id},
|
where={"id": completion_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -210,3 +210,32 @@ async def test_inference_store_pagination_no_limit():
|
||||||
assert result.data[0].id == "beta-second" # Most recent first
|
assert result.data[0].id == "beta-second" # Most recent first
|
||||||
assert result.data[1].id == "omega-first"
|
assert result.data[1].id == "omega-first"
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_inference_store_custom_table_name():
    """Exercise InferenceStore with a non-default table_name in its reference.

    Stores one chat completion, reads it back by id and through the listing
    call, then checks that the "not found" error raised for an unknown
    pagination cursor names the custom table.
    """
    custom_table_name = "custom_inference_store"
    store = InferenceStore(
        InferenceStoreReference(backend="sql_default", table_name=custom_table_name),
        policy=[],
    )
    await store.initialize()

    # Write a single completion and flush the store before reading back.
    created_at = int(time.time())
    await store.store_chat_completion(
        create_test_chat_completion("custom-table-test", created_at),
        [OpenAIUserMessageParam(role="user", content="Test custom table")],
    )
    await store.flush()

    # Lookup by id returns the stored completion.
    fetched = await store.get_chat_completion("custom-table-test")
    assert fetched.id == "custom-table-test"
    assert fetched.model == "test-model"

    # Listing sees exactly the one stored completion.
    listing = await store.list_chat_completions()
    assert len(listing.data) == 1
    assert listing.data[0].id == "custom-table-test"

    # An unknown cursor id must produce an error that mentions the custom table.
    expected_error = f"Record with id='non-existent' not found in table '{custom_table_name}'"
    with pytest.raises(ValueError, match=expected_error):
        await store.list_chat_completions(after="non-existent", limit=2)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue