mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	# What does this PR do?
	* Provides a SQLite implementation of the APIs introduced in https://github.com/meta-llama/llama-stack/pull/2145.
	* Introduces a SqlStore API (llama_stack/providers/utils/sqlstore/api.py) and its first implementation, backed by SQLite.
	* Pagination support will be added in a future PR.
	## Test Plan
	Unit test on sql store: <img width="1005" alt="image" src="https://github.com/user-attachments/assets/9b8b7ec8-632b-4667-8127-5583426b2e29" />
	Integration test:
	```
	INFERENCE_MODEL="llama3.2:3b-instruct-fp16" llama stack build --template ollama --image-type conda --run
	```
	```
	LLAMA_STACK_CONFIG=http://localhost:5001 INFERENCE_MODEL="llama3.2:3b-instruct-fp16" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-fp16" -k 'inference_store and openai'
	```
		
			
				
	
	
		
			123 lines
		
	
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			123 lines
		
	
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| from llama_stack.apis.inference import (
 | |
|     ListOpenAIChatCompletionResponse,
 | |
|     OpenAIChatCompletion,
 | |
|     OpenAICompletionWithInputMessages,
 | |
|     OpenAIMessageParam,
 | |
|     Order,
 | |
| )
 | |
| from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 | |
| 
 | |
| from ..sqlstore.api import ColumnDefinition, ColumnType
 | |
| from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
 | |
| 
 | |
| 
 | |
class InferenceStore:
    """Persist chat completions, together with their input messages, in a SQL store.

    The store is created lazily: construct the object, then ``await
    initialize()`` before calling any other method. Every data method raises
    ``ValueError`` if called before initialization.
    """

    # Single table used by every read and write in this class.
    _TABLE_NAME = "chat_completions"

    def __init__(self, sql_store_config: SqlStoreConfig):
        """
        Remember the SQL store configuration without opening the store.

        :param sql_store_config: Configuration for the backing SQL store. A falsy
            value (e.g. ``None``) falls back to a SQLite database named
            ``sqlstore.db`` under ``RUNTIME_BASE_DIR``.
        """
        if not sql_store_config:
            sql_store_config = SqliteSqlStoreConfig(
                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
            )
        self.sql_store_config = sql_store_config
        # Populated by initialize(); None means "not initialized yet".
        self.sql_store = None

    def _require_store(self):
        """Return the initialized SQL store, or raise ValueError if initialize() has not run."""
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")
        return self.sql_store

    @staticmethod
    def _row_to_completion(row) -> OpenAICompletionWithInputMessages:
        """Build an OpenAICompletionWithInputMessages from one ``chat_completions`` row."""
        return OpenAICompletionWithInputMessages(
            id=row["id"],
            created=row["created"],
            model=row["model"],
            choices=row["choices"],
            input_messages=row["input_messages"],
        )

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        self.sql_store = sqlstore_impl(self.sql_store_config)
        await self.sql_store.create_table(
            self._TABLE_NAME,
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created": ColumnType.INTEGER,
                "model": ColumnType.STRING,
                "choices": ColumnType.JSON,
                "input_messages": ColumnType.JSON,
            },
        )

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        """
        Insert one chat completion and the messages that produced it.

        :param chat_completion: The completion to persist; its ``id``, ``created``,
            ``model`` and ``choices`` fields are stored.
        :param input_messages: The request messages, stored alongside the completion.
        :raises ValueError: If the store has not been initialized.
        """
        store = self._require_store()

        data = chat_completion.model_dump()

        await store.insert(
            self._TABLE_NAME,
            {
                "id": data["id"],
                "created": data["created"],
                "model": data["model"],
                "choices": data["choices"],
                "input_messages": [message.model_dump() for message in input_messages],
            },
        )

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """
        List chat completions from the database.

        :param after: Pagination cursor. Not supported yet: any truthy value
            raises ``NotImplementedError``.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by; a falsy value disables the filter.
        :param order: The order to sort the chat completions by (applied to the
            ``created`` column). A falsy value falls back to ``Order.desc``.
        :returns: The matching completions. ``has_more`` is always ``False``, and
            ``first_id`` / ``last_id`` are empty strings when nothing matched.
        :raises ValueError: If the store has not been initialized.
        :raises NotImplementedError: If ``after`` is provided.
        """
        store = self._require_store()

        # TODO: support after
        if after:
            raise NotImplementedError("After is not supported for SQLite")
        if not order:
            order = Order.desc

        rows = await store.fetch_all(
            self._TABLE_NAME,
            where={"model": model} if model else None,
            order_by=[("created", order.value)],
            limit=limit,
        )

        data = [self._row_to_completion(row) for row in rows]
        return ListOpenAIChatCompletionResponse(
            data=data,
            # TODO: implement has_more
            has_more=False,
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        """
        Fetch a single chat completion by its ID.

        :param completion_id: The ID of the chat completion to look up.
        :returns: The stored completion together with its input messages.
        :raises ValueError: If the store has not been initialized, or if no row
            with the given ID exists.
        """
        store = self._require_store()

        row = await store.fetch_one(self._TABLE_NAME, where={"id": completion_id})
        if not row:
            raise ValueError(f"Chat completion with id {completion_id} not found") from None
        return self._row_to_completion(row)
 |