mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: implement get chat completions APIs (#2200)
# What does this PR do? * Provide sqlite implementation of the APIs introduced in https://github.com/meta-llama/llama-stack/pull/2145. * Introduced a SqlStore API: llama_stack/providers/utils/sqlstore/api.py and the first Sqlite implementation * Pagination support will be added in a future PR. ## Test Plan Unit test on sql store: <img width="1005" alt="image" src="https://github.com/user-attachments/assets/9b8b7ec8-632b-4667-8127-5583426b2e29" /> Integration test: ``` INFERENCE_MODEL="llama3.2:3b-instruct-fp16" llama stack build --template ollama --image-type conda --run ``` ``` LLAMA_STACK_CONFIG=http://localhost:5001 INFERENCE_MODEL="llama3.2:3b-instruct-fp16" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-fp16" -k 'inference_store and openai' ```
This commit is contained in:
parent
633bb9c5b3
commit
549812f51e
71 changed files with 1111 additions and 10 deletions
|
@ -43,8 +43,20 @@ def get_provider_dependencies(
|
|||
# Extract providers based on config type
|
||||
if isinstance(config, DistributionTemplate):
|
||||
providers = config.providers
|
||||
|
||||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
run_configs = config.run_configs
|
||||
additional_pip_packages: list[str] = []
|
||||
if run_configs:
|
||||
for run_config in run_configs.values():
|
||||
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
elif isinstance(config, BuildConfig):
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
|
@ -72,6 +84,9 @@ def get_provider_dependencies(
|
|||
else:
|
||||
normal_deps.append(package)
|
||||
|
||||
if additional_pip_packages:
|
||||
normal_deps.extend(additional_pip_packages)
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
@ -314,6 +315,13 @@ Configuration for the persistence store used by the distribution registry. If no
|
|||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
inference_store: SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the inference API. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
|
@ -362,6 +370,10 @@ class BuildConfig(BaseModel):
|
|||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
)
|
||||
additional_pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
|
|
|
@ -140,7 +140,7 @@ async def resolve_impls(
|
|||
|
||||
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
|
||||
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
|
@ -243,7 +243,10 @@ def sort_providers_by_deps(
|
|||
|
||||
|
||||
async def instantiate_providers(
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]],
|
||||
router_apis: set[Api],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
) -> dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
|
@ -258,7 +261,7 @@ async def instantiate_providers(
|
|||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
|
||||
|
||||
if api_str.startswith("inner-"):
|
||||
inner_impls_by_provider_id[api_str][provider.provider_id] = impl
|
||||
|
@ -308,6 +311,7 @@ async def instantiate_provider(
|
|||
deps: dict[Api, Any],
|
||||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
):
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
|
@ -327,7 +331,7 @@ async def instantiate_provider(
|
|||
method = "get_auto_router_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
|
||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
from llama_stack.distribution.stack import StackRunConfig
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
from .routing_tables import (
|
||||
BenchmarksRoutingTable,
|
||||
|
@ -45,7 +47,9 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
|
||||
async def get_auto_router_impl(
|
||||
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
|
||||
) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
@ -76,6 +80,12 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
|
|||
if dep_api in deps:
|
||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||
|
||||
# TODO: move pass configs to routers instead
|
||||
if api == Api.inference and run_config.inference_store:
|
||||
inference_store = InferenceStore(run_config.inference_store)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -32,8 +32,11 @@ from llama_stack.apis.inference import (
|
|||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAICompletionWithInputMessages,
|
||||
Order,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
|
@ -73,6 +76,8 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -141,10 +146,12 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Telemetry | None = None,
|
||||
store: InferenceStore | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
self.store = store
|
||||
if self.telemetry:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
@ -607,9 +614,31 @@ class InferenceRouter(Inference):
|
|||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_chat_completion(**params)
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
|
||||
return response_stream
|
||||
else:
|
||||
return await self._nonstream_openai_chat_completion(provider, params)
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
return response
|
||||
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 20,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
if self.store:
|
||||
return await self.store.list_chat_completions(after, limit, model, order)
|
||||
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
|
||||
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
if self.store:
|
||||
return await self.store.get_chat_completion(completion_id)
|
||||
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
|
|
123
llama_stack/providers/utils/inference/inference_store.py
Normal file
123
llama_stack/providers/utils/inference/inference_store.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference import (
|
||||
ListOpenAIChatCompletionResponse,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
|
||||
|
||||
class InferenceStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store_config = sql_store_config
|
||||
self.sql_store = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = sqlstore_impl(self.sql_store_config)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created": ColumnType.INTEGER,
|
||||
"model": ColumnType.STRING,
|
||||
"choices": ColumnType.JSON,
|
||||
"input_messages": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def store_chat_completion(
|
||||
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
|
||||
) -> None:
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
data = chat_completion.model_dump()
|
||||
|
||||
await self.sql_store.insert(
|
||||
"chat_completions",
|
||||
{
|
||||
"id": data["id"],
|
||||
"created": data["created"],
|
||||
"model": data["model"],
|
||||
"choices": data["choices"],
|
||||
"input_messages": [message.model_dump() for message in input_messages],
|
||||
},
|
||||
)
|
||||
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""
|
||||
List chat completions from the database.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the chat completions by.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"chat_completions",
|
||||
where={"model": model} if model else None,
|
||||
order_by=[("created", order.value)],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
data = [
|
||||
OpenAICompletionWithInputMessages(
|
||||
id=row["id"],
|
||||
created=row["created"],
|
||||
model=row["model"],
|
||||
choices=row["choices"],
|
||||
input_messages=row["input_messages"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
return ListOpenAIChatCompletionResponse(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
|
||||
if not row:
|
||||
raise ValueError(f"Chat completion with id {completion_id} not found") from None
|
||||
return OpenAICompletionWithInputMessages(
|
||||
id=row["id"],
|
||||
created=row["created"],
|
||||
model=row["model"],
|
||||
choices=row["choices"],
|
||||
input_messages=row["input_messages"],
|
||||
)
|
129
llama_stack/providers/utils/inference/stream_utils.py
Normal file
129
llama_stack/providers/utils/inference/stream_utils.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
|
||||
async def stream_and_store_openai_completion(
|
||||
provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
model: str,
|
||||
store: InferenceStore,
|
||||
input_messages: list[OpenAIMessageParam],
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Wraps a provider's stream, yields chunks, and stores the full completion at the end.
|
||||
"""
|
||||
id = None
|
||||
created = None
|
||||
choices_data: dict[int, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
async for chunk in provider_stream:
|
||||
if id is None and chunk.id:
|
||||
id = chunk.id
|
||||
if created is None and chunk.created:
|
||||
created = chunk.created
|
||||
|
||||
if chunk.choices:
|
||||
for choice_delta in chunk.choices:
|
||||
idx = choice_delta.index
|
||||
if idx not in choices_data:
|
||||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
||||
if choice_delta.delta:
|
||||
delta = choice_delta.delta
|
||||
if delta.content:
|
||||
current_choice_data["content_parts"].append(delta.content)
|
||||
if delta.tool_calls:
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
tc_idx = tool_call_delta.index
|
||||
if tc_idx not in current_choice_data["tool_calls_builder"]:
|
||||
# Initialize with correct structure for _ToolCallBuilderData
|
||||
current_choice_data["tool_calls_builder"][tc_idx] = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function_name_parts": [],
|
||||
"function_arguments_parts": [],
|
||||
}
|
||||
builder = current_choice_data["tool_calls_builder"][tc_idx]
|
||||
if tool_call_delta.id:
|
||||
builder["id"] = tool_call_delta.id
|
||||
if tool_call_delta.type:
|
||||
builder["type"] = tool_call_delta.type
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
builder["function_name_parts"].append(tool_call_delta.function.name)
|
||||
if tool_call_delta.function.arguments:
|
||||
builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
# Ensure that we are extending with the correct type
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
yield chunk
|
||||
finally:
|
||||
if id:
|
||||
assembled_choices: list[OpenAIChoice] = []
|
||||
for choice_idx, choice_data in choices_data.items():
|
||||
content_str = "".join(choice_data["content_parts"])
|
||||
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
|
||||
if choice_data["tool_calls_builder"]:
|
||||
for tc_build_data in choice_data["tool_calls_builder"].values():
|
||||
if tc_build_data["id"]:
|
||||
func_name = "".join(tc_build_data["function_name_parts"])
|
||||
func_args = "".join(tc_build_data["function_arguments_parts"])
|
||||
assembled_tool_calls.append(
|
||||
OpenAIChatCompletionToolCall(
|
||||
id=tc_build_data["id"],
|
||||
type=tc_build_data["type"], # No or "function" needed, already set
|
||||
function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
|
||||
)
|
||||
)
|
||||
message = OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content_str if content_str else None,
|
||||
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
|
||||
)
|
||||
logprobs_content = choice_data["logprobs_content_parts"]
|
||||
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
assembled_choices.append(
|
||||
OpenAIChoice(
|
||||
finish_reason=choice_data["finish_reason"],
|
||||
index=choice_idx,
|
||||
message=message,
|
||||
logprobs=final_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
final_response = OpenAIChatCompletion(
|
||||
id=id,
|
||||
choices=assembled_choices,
|
||||
created=created or int(datetime.now(timezone.utc).timestamp()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
)
|
||||
await store.store_chat_completion(final_response, input_messages)
|
90
llama_stack/providers/utils/sqlstore/api.py
Normal file
90
llama_stack/providers/utils/sqlstore/api.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ColumnType(Enum):
|
||||
INTEGER = "INTEGER"
|
||||
STRING = "STRING"
|
||||
TEXT = "TEXT"
|
||||
FLOAT = "FLOAT"
|
||||
BOOLEAN = "BOOLEAN"
|
||||
JSON = "JSON"
|
||||
DATETIME = "DATETIME"
|
||||
|
||||
|
||||
class ColumnDefinition(BaseModel):
|
||||
type: ColumnType
|
||||
primary_key: bool = False
|
||||
nullable: bool = True
|
||||
default: Any = None
|
||||
|
||||
|
||||
class SqlStore(Protocol):
|
||||
"""
|
||||
A protocol for a SQL store.
|
||||
"""
|
||||
|
||||
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
|
||||
"""
|
||||
Create a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Insert a row into a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all rows from a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Fetch one row from a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def update(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Update a row in a table.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Delete a row from a table.
|
||||
"""
|
||||
pass
|
161
llama_stack/providers/utils/sqlstore/sqlite/sqlite.py
Normal file
161
llama_stack/providers/utils/sqlstore/sqlite/sqlite.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
MetaData,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from ..api import ColumnDefinition, ColumnType, SqlStore
|
||||
from ..sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
ColumnType.STRING: String,
|
||||
ColumnType.FLOAT: Float,
|
||||
ColumnType.BOOLEAN: Boolean,
|
||||
ColumnType.DATETIME: DateTime,
|
||||
ColumnType.TEXT: Text,
|
||||
ColumnType.JSON: JSON,
|
||||
}
|
||||
|
||||
|
||||
class SqliteSqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqliteSqlStoreConfig):
|
||||
self.engine = create_async_engine(config.engine_str)
|
||||
self.metadata = MetaData()
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
table: str,
|
||||
schema: Mapping[str, ColumnType | ColumnDefinition],
|
||||
) -> None:
|
||||
if not schema:
|
||||
raise ValueError(f"No columns defined for table '{table}'.")
|
||||
|
||||
sqlalchemy_columns: list[Column] = []
|
||||
|
||||
for col_name, col_props in schema.items():
|
||||
col_type = None
|
||||
is_primary_key = False
|
||||
is_nullable = True # Default to nullable
|
||||
|
||||
if isinstance(col_props, ColumnType):
|
||||
col_type = col_props
|
||||
elif isinstance(col_props, ColumnDefinition):
|
||||
col_type = col_props.type
|
||||
is_primary_key = col_props.primary_key
|
||||
is_nullable = col_props.nullable
|
||||
|
||||
sqlalchemy_type = TYPE_MAPPING.get(col_type)
|
||||
if not sqlalchemy_type:
|
||||
raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
|
||||
|
||||
sqlalchemy_columns.append(
|
||||
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
|
||||
)
|
||||
|
||||
# Check if table already exists in metadata, otherwise define it
|
||||
if table not in self.metadata.tables:
|
||||
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
|
||||
else:
|
||||
sqlalchemy_table = self.metadata.tables[table]
|
||||
|
||||
# Create the table in the database if it doesn't exist
|
||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.execute(self.metadata.tables[table].insert(), data)
|
||||
await conn.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self.engine.begin() as conn:
|
||||
query = select(self.metadata.tables[table])
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(self.metadata.tables[table].c[key] == value)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
if order_by:
|
||||
if not isinstance(order_by, list):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
for order in order_by:
|
||||
if not isinstance(order, tuple):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
name, order_type = order
|
||||
if order_type == "asc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].asc())
|
||||
elif order_type == "desc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
result = await conn.execute(query)
|
||||
if result.rowcount == 0:
|
||||
return []
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
rows = await self.fetch_all(table, where, limit=1, order_by=order_by)
|
||||
if not rows:
|
||||
return None
|
||||
return rows[0]
|
||||
|
||||
async def update(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for update")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt, data)
|
||||
await conn.commit()
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for delete")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt)
|
||||
await conn.commit()
|
72
llama_stack/providers/utils/sqlstore/sqlstore.py
Normal file
72
llama_stack/providers/utils/sqlstore/sqlstore.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
from .api import SqlStore
|
||||
|
||||
|
||||
class SqlStoreType(Enum):
|
||||
sqlite = "sqlite"
|
||||
postgres = "postgres"
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(BaseModel):
|
||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||
)
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
|
||||
return cls(
|
||||
type="sqlite",
|
||||
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
)
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(BaseModel):
|
||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
|
||||
Field(discriminator="type", default=SqlStoreType.sqlite.value),
|
||||
]
|
||||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type == SqlStoreType.sqlite.value:
|
||||
from .sqlite.sqlite import SqliteSqlStoreImpl
|
||||
|
||||
impl = SqliteSqlStoreImpl(config)
|
||||
elif config.type == SqlStoreType.postgres.value:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||
|
||||
return impl
|
|
@ -29,3 +29,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -96,6 +96,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta.llama3-1-8b-instruct-v1:0
|
||||
|
|
|
@ -29,3 +29,5 @@ distribution_spec:
|
|||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -99,6 +99,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: llama3.1-8b
|
||||
|
|
|
@ -30,3 +30,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -99,6 +99,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||
|
|
|
@ -30,3 +30,6 @@ distribution_spec:
|
|||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -99,6 +99,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -95,6 +95,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -67,6 +68,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -105,6 +107,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -145,6 +148,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -184,6 +188,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -221,6 +226,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -259,6 +265,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -297,6 +304,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -335,6 +343,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -379,6 +388,7 @@
|
|||
"scipy",
|
||||
"sentence-transformers",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"torch",
|
||||
"torchao==0.8.0",
|
||||
"torchvision",
|
||||
|
@ -414,6 +424,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn"
|
||||
|
@ -452,6 +463,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"torch",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -490,6 +502,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"sqlite-vec",
|
||||
"together",
|
||||
"tqdm",
|
||||
|
@ -528,6 +541,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -566,6 +580,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -599,6 +614,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
|
@ -637,6 +653,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -678,6 +695,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -716,6 +734,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"together",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -755,6 +774,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
|
@ -794,6 +814,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
@ -833,6 +854,7 @@
|
|||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlalchemy[asyncio]",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
|
|
|
@ -31,3 +31,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -111,6 +111,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||
|
|
|
@ -106,6 +106,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||
|
|
|
@ -26,3 +26,5 @@ distribution_spec:
|
|||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -99,6 +99,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: groq/llama3-8b-8192
|
||||
|
|
|
@ -29,3 +29,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -107,6 +107,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -102,6 +102,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -30,3 +30,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -107,6 +107,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -102,6 +102,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -30,3 +30,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -111,6 +111,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: Llama-3.3-70B-Instruct
|
||||
|
|
|
@ -29,3 +29,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -117,6 +117,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -107,6 +107,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -24,3 +24,6 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- inline::rag-runtime
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -92,6 +92,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -80,6 +80,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-8b-instruct
|
||||
|
|
|
@ -32,3 +32,6 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -112,6 +112,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -110,6 +110,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -33,3 +33,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -125,6 +125,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
|
|
|
@ -31,3 +31,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -111,6 +111,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
|
|
|
@ -106,6 +106,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
|
|
|
@ -31,3 +31,6 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -115,6 +115,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -108,6 +108,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -22,3 +22,5 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -82,6 +82,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: sambanova/Meta-Llama-3.1-8B-Instruct
|
||||
|
|
|
@ -35,3 +35,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -133,6 +133,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
def get_model_registry(
|
||||
|
@ -117,6 +118,10 @@ class RunConfigSettings(BaseModel):
|
|||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="registry.db",
|
||||
),
|
||||
inference_store=SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=f"~/.llama/distributions/{name}",
|
||||
db_name="inference_store.db",
|
||||
),
|
||||
models=self.default_models or [],
|
||||
shields=self.default_shields or [],
|
||||
tool_groups=self.default_tool_groups or [],
|
||||
|
@ -146,14 +151,20 @@ class DistributionTemplate(BaseModel):
|
|||
available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None
|
||||
|
||||
def build_config(self) -> BuildConfig:
|
||||
additional_pip_packages: list[str] = []
|
||||
for run_config in self.run_configs.values():
|
||||
run_config_ = run_config.run_config(self.name, self.providers, self.container_image)
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
|
||||
return BuildConfig(
|
||||
name=self.name,
|
||||
distribution_spec=DistributionSpec(
|
||||
description=self.description,
|
||||
container_image=self.container_image,
|
||||
providers=self.providers,
|
||||
),
|
||||
image_type="conda", # default to conda, can be overridden
|
||||
additional_pip_packages=additional_pip_packages,
|
||||
)
|
||||
|
||||
def generate_markdown_docs(self) -> str:
|
||||
|
|
|
@ -30,3 +30,6 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -102,6 +102,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -101,6 +101,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -31,3 +31,6 @@ distribution_spec:
|
|||
- remote::model-context-protocol
|
||||
- remote::wolfram-alpha
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -111,6 +111,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
|
||||
|
|
|
@ -106,6 +106,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
|
||||
|
|
|
@ -35,3 +35,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -135,6 +135,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
|
|
|
@ -30,3 +30,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -106,6 +106,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
|
|
|
@ -28,3 +28,5 @@ distribution_spec:
|
|||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -103,6 +103,9 @@ providers:
|
|||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/llama-3-3-70b-instruct
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue