mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 03:49:40 +00:00
Merge branch 'main' into feature/dpo-training
This commit is contained in:
commit
b68b818539
265 changed files with 10254 additions and 7796 deletions
|
|
@ -10,6 +10,7 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
|
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
|
|||
|
||||
|
||||
async def get_raw_document_text(document: Document) -> str:
|
||||
if not document.mime_type.startswith("text/"):
|
||||
# Handle deprecated text/yaml mime type with warning
|
||||
if document.mime_type == "text/yaml":
|
||||
warnings.warn(
|
||||
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
|
|
|
|||
|
|
@ -128,6 +128,11 @@ class AgentPersistence:
|
|||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import AccessRule, Api
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
from .files import LocalfsFilesImpl
|
||||
|
|
@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl
|
|||
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
|
||||
impl = LocalfsFilesImpl(config)
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
impl = LocalfsFilesImpl(config, policy)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -19,16 +19,19 @@ from llama_stack.apis.files import (
|
|||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
|
||||
class LocalfsFilesImpl(Files):
|
||||
def __init__(self, config: LocalfsFilesImplConfig) -> None:
|
||||
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self.config = config
|
||||
self.sql_store: SqlStore | None = None
|
||||
self.policy = policy
|
||||
self.sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the files provider by setting up storage directory and metadata database."""
|
||||
|
|
@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = sqlstore_impl(self.config.metadata_store)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
|
@ -126,6 +129,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
@ -156,7 +160,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
@ -174,7 +178,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
@ -197,7 +201,7 @@ class LocalfsFilesImpl(Files):
|
|||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# Get file metadata
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
|
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
|
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMConfig(BaseModel):
|
||||
"""Configuration for the vLLM inference provider.
|
||||
|
||||
Note that the model name is no longer part of this static configuration.
|
||||
You can bind an instance of this provider to a specific model with the
|
||||
``models.register()`` API call."""
|
||||
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
|
||||
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
|
||||
enforce_eager: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
)
|
||||
gpu_memory_utilization: float = Field(
|
||||
default=0.3,
|
||||
description=(
|
||||
"How much GPU memory will be allocated when this provider has finished "
|
||||
"loading, including memory that was already allocated before loading."
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:=4096}",
|
||||
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
|
||||
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
|
||||
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
|
||||
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
|
||||
}
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import vllm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
# This file contains OpenAI compatibility code that is currently only used
|
||||
# by the inline vLLM connector. Some or all of this code may be moved to a
|
||||
# central location at a later date.
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compabilitiy with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emumate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
|
||||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
"""
|
||||
if tools is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for t in tools:
|
||||
if isinstance(t.tool_name, BuiltinTool):
|
||||
raise NotImplementedError("Built-in tools not yet implemented")
|
||||
if t.parameters is None:
|
||||
parameters = None
|
||||
else: # if t.parameters is not None
|
||||
# Convert the "required" flags to a list of required params
|
||||
required_params = [k for k, v in t.parameters.items() if v.required]
|
||||
parameters = {
|
||||
"type": "object", # Mystery value that shows up in OpenAI docs
|
||||
"properties": {
|
||||
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
|
||||
},
|
||||
"required": required_params,
|
||||
}
|
||||
|
||||
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
|
||||
name=t.tool_name, description=t.description, parameters=parameters
|
||||
)
|
||||
|
||||
# Every tool definition is double-boxed in a ChatCompletionToolsParam
|
||||
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
|
||||
return result
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# This mystery async call makes the parent function also be async
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tools": converted_tools,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
|
||||
|
|
@ -1,811 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama import sku_list
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
||||
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
"MllamaConfig": "llama3_json",
|
||||
"LlamaConfig": "llama3_json",
|
||||
}
|
||||
DEFAULT_TOOL_PARSER = "pythonic"
|
||||
|
||||
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def _random_uuid_str() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: SamplingParams | None,
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
log_prob_config: LogProbConfig | None,
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
log_prob_config = LogProbConfig()
|
||||
|
||||
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
|
||||
if sampling_params.strategy.top_k == 0:
|
||||
# vLLM treats "k" differently for top-k sampling
|
||||
vllm_top_k = -1
|
||||
else:
|
||||
vllm_top_k = sampling_params.strategy.top_k
|
||||
else:
|
||||
vllm_top_k = -1
|
||||
|
||||
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
vllm_top_p = sampling_params.strategy.top_p
|
||||
# Llama Stack only allows temperature with top-P.
|
||||
vllm_temperature = sampling_params.strategy.temperature
|
||||
else:
|
||||
vllm_top_p = 1.0
|
||||
vllm_temperature = 0.0
|
||||
|
||||
# vLLM allows top-p and top-k at the same time.
|
||||
vllm_sampling_params = vllm.SamplingParams.from_optional(
|
||||
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
|
||||
temperature=vllm_temperature,
|
||||
top_p=vllm_top_p,
|
||||
top_k=vllm_top_k,
|
||||
repetition_penalty=sampling_params.repetition_penalty,
|
||||
guided_decoding=_response_format_to_guided_decoding_params(response_format),
|
||||
logprobs=log_prob_config.top_k,
|
||||
)
|
||||
return vllm_sampling_params
|
||||
|
||||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
register_helper: ModelRegistryHelper
|
||||
model_ids: set[str]
|
||||
resolved_model_id: str | None
|
||||
engine: AsyncLLMEngine | None
|
||||
chat: OpenAIServingChat | None
|
||||
is_meta_llama_model: bool
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
logger.info(f"Config is: {self.config}")
|
||||
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# The following are initialized when paths are bound to this provider
|
||||
self.resolved_model_id = None
|
||||
self.model_ids = set()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.is_meta_llama_model = False
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after when __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.model_ids = set()
|
||||
self.resolved_model_id = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
|
||||
|
||||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
logger.debug(f"In register_model({model})")
|
||||
|
||||
# First attempt to interpret the model coordinates as a Llama model name
|
||||
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
|
||||
if resolved_llama_model is not None:
|
||||
# Load from Hugging Face repo into default local cache dir
|
||||
model_id_for_vllm = resolved_llama_model.huggingface_repo
|
||||
|
||||
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
|
||||
# Don't set self.is_meta_llama_model until we actually load the model.
|
||||
is_meta_llama_model = True
|
||||
else: # if resolved_llama_model is None
|
||||
# Not a Llama model name. Pass the model id through to vLLM's loader
|
||||
model_id_for_vllm = model.provider_model_id
|
||||
is_meta_llama_model = False
|
||||
|
||||
if self.resolved_model_id is not None:
|
||||
if model_id_for_vllm != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
logger.info(
|
||||
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
|
||||
)
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
|
||||
if is_meta_llama_model:
|
||||
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
|
||||
self.is_meta_llama_model = is_meta_llama_model
|
||||
|
||||
# If we get here, this is the first time registering a model.
|
||||
# Preload so that the first inference request won't time out.
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_id_for_vllm,
|
||||
tokenizer=model_id_for_vllm,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
max_num_seqs=self.config.max_num_seqs,
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
|
||||
# parser, we need to determine what model architecture is being used. For now, we infer
|
||||
# that information from what config class the model uses.
|
||||
low_level_model_config = self.engine.engine.get_model_config()
|
||||
hf_config = low_level_model_config.hf_config
|
||||
hf_config_class_name = hf_config.__class__.__name__
|
||||
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
|
||||
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
|
||||
else:
|
||||
# No info -- choose a default so we can at least attempt tool
|
||||
# use.
|
||||
tool_parser = DEFAULT_TOOL_PARSER
|
||||
logger.debug(f"{hf_config_class_name=}")
|
||||
logger.debug(f"{tool_parser=}")
|
||||
|
||||
# Wrap the lower-level engine in an OpenAI-compatible chat API
|
||||
model_config = await self.engine.get_model_config()
|
||||
self.chat = OpenAIServingChat(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
models=OpenAIServingModels(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=[
|
||||
# The layer below us will only see resolved model IDs
|
||||
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
|
||||
],
|
||||
),
|
||||
response_role="assistant",
|
||||
request_logger=None, # Use default logging
|
||||
chat_template=None, # Use default template from model checkpoint
|
||||
enable_auto_tools=True,
|
||||
tool_parser=tool_parser,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
self.resolved_model_id = model_id_for_vllm
|
||||
self.model_ids.add(model.model_id)
|
||||
|
||||
logger.info(f"Finished preloading model: {model_id_for_vllm}")
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM Inference INTERFACE
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
raise NotImplementedError("Multimodal input not currently supported")
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
|
||||
|
||||
logger.debug(f"{converted_sampling_params=}")
|
||||
|
||||
if stream:
|
||||
return self._streaming_completion(content, converted_sampling_params)
|
||||
else:
|
||||
streaming_result = None
|
||||
async for _ in self._streaming_completion(content, converted_sampling_params):
|
||||
pass
|
||||
return CompletionResponse(
|
||||
content=streaming_result.delta,
|
||||
stop_reason=streaming_result.stop_reason,
|
||||
logprobs=streaming_result.logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message], # type: ignore
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None, # type: ignore
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Convert to Llama Stack internal format for consistency
|
||||
request = ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if self.is_meta_llama_model:
|
||||
# Bypass vLLM chat templating layer for Meta Llama models, because the
|
||||
# templating layer in Llama Stack currently produces better results.
|
||||
logger.debug(
|
||||
f"Routing {self.resolved_model_id} chat completion through "
|
||||
f"Llama Stack's templating layer instead of vLLM's."
|
||||
)
|
||||
return await self._chat_completion_for_meta_llama(request)
|
||||
|
||||
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
|
||||
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
|
||||
|
||||
logger.debug(f"Converted request: {chat_completion_request}")
|
||||
|
||||
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
|
||||
logger.debug(f"Result from vLLM: {vllm_result}")
|
||||
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
|
||||
raise ValueError(f"Error from vLLM layer: {vllm_result}")
|
||||
|
||||
# Return type depends on "stream" argument
|
||||
if stream:
|
||||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
|
||||
return self._convert_non_streaming_results(vllm_result)
|
||||
|
||||
###########################################################################
|
||||
# INTERNAL METHODS
|
||||
|
||||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Paramters from public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to VLLM format
|
||||
"""
|
||||
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
|
||||
# we drill down to the LLMEngine inside the AsyncLLMEngine.
|
||||
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
|
||||
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
|
||||
llm_engine = self.engine.engine
|
||||
tokenizer_group = llm_engine.tokenizer
|
||||
eos_token_id = tokenizer_group.tokenizer.eos_token_id
|
||||
|
||||
request_output: vllm.RequestOutput = None
|
||||
async for request_output in results_generator:
|
||||
# Check for weird inference failures
|
||||
if request_output.outputs is None or len(request_output.outputs) == 0:
|
||||
# This case also should never happen
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
# Convert logprobs from vLLM's format to Llama Stack's format
|
||||
logprobs = [
|
||||
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
|
||||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
|
||||
if output.stop_reason is None:
|
||||
stop_reason = None # Still going
|
||||
elif output.stop_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif output.stop_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
elif isinstance(output.stop_reason, int):
|
||||
# If the model config specifies multiple end-of-sequence tokens, then vLLM
|
||||
# will return the token ID of the EOS token in the stop_reason field.
|
||||
stop_reason = StopReason.end_of_turn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def _convert_non_streaming_results(
|
||||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
vllm_finish_reason = vllm_result.choices[0].finish_reason
|
||||
|
||||
converted_message = CompletionMessage(
|
||||
role=vllm_message.role,
|
||||
# Llama Stack API won't accept None for content field.
|
||||
content=("" if vllm_message.content is None else vllm_message.content),
|
||||
stop_reason=get_stop_reason(vllm_finish_reason),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id=t.id,
|
||||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
arguments_json=t.function.arguments,
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Convert logprobs
|
||||
|
||||
logger.debug(f"Converted message: {converted_message}")
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=converted_message,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
of the chat template currently produces more reliable outputs.
|
||||
|
||||
Once vLLM's support for Meta Llama models has matured more, we should consider routing
|
||||
Meta Llama requests through the vLLM chat completions API instead of using this method.
|
||||
"""
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# Note that this function call modifies `request` in place.
|
||||
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
|
||||
|
||||
model_id = list(self.model_ids)[0] # Any model ID will do here
|
||||
completion_response_or_iterator = await self.completion(
|
||||
model_id=model_id,
|
||||
content=prompt,
|
||||
sampling_params=request.sampling_params,
|
||||
response_format=request.response_format,
|
||||
stream=request.stream,
|
||||
logprobs=request.logprobs,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
if not isinstance(completion_response_or_iterator, AsyncIterator):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
|
||||
)
|
||||
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
|
||||
|
||||
# elsif not request.stream:
|
||||
if not isinstance(completion_response_or_iterator, CompletionResponse):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
|
||||
)
|
||||
completion_response: CompletionResponse = completion_response_or_iterator
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
completion_response.content, completion_response.stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=completion_response.logprobs,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama_streaming(
|
||||
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
|
||||
method to keep asyncio happy.
|
||||
"""
|
||||
|
||||
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
chunk: CompletionResponseStreamChunk # Make Pylance happy
|
||||
last_text_len = 0
|
||||
async for chunk in results_iterator:
|
||||
if chunk.stop_reason == StopReason.end_of_turn:
|
||||
finish_reason = "stop"
|
||||
elif chunk.stop_reason == StopReason.end_of_message:
|
||||
finish_reason = "eos"
|
||||
elif chunk.stop_reason == StopReason.out_of_tokens:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
|
||||
# Convert delta back to an actual delta
|
||||
text_delta = chunk.delta[last_text_len:]
|
||||
last_text_len = len(chunk.delta)
|
||||
|
||||
logger.debug(f"{text_delta=}; {finish_reason=}")
|
||||
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
logger.debug(f"Returning chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
|
||||
"""
|
||||
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
|
||||
API into a second async iterator that returns Llama Stack objects.
|
||||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: dict[int, dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
|
||||
|
||||
# In between the "data: " and newlines is an event record
|
||||
data_str = chunk_str[len(_prefix) : -len(_suffix)]
|
||||
|
||||
# The end of the stream is indicated with "[DONE]"
|
||||
if data_str == "[DONE]":
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Anything that is not "[DONE]" should be a JSON record
|
||||
parsed_chunk = json.loads(data_str)
|
||||
|
||||
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
||||
if "content" in delta_record:
|
||||
# Text delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=delta_record["content"]),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
del tool_call_record["arguments_str"]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
|
||||
|
|
@ -319,7 +319,7 @@ class HFFinetuningSingleDevice:
|
|||
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
|
||||
save_strategy=save_strategy,
|
||||
report_to="none",
|
||||
max_seq_length=provider_config.max_seq_length,
|
||||
max_length=provider_config.max_seq_length,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
gradient_checkpointing=provider_config.gradient_checkpointing,
|
||||
learning_rate=lr,
|
||||
|
|
|
|||
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
|
|||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
|
|
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
|
|
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
|
||||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
|
||||
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
|
||||
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
|||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
|
||||
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,25 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ChromaVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
def sample_run_config(
|
||||
cls, __distro_dir__: str, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": db_path,
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="chroma_inline_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
||||
# A list of chunk id's in the same order as they are in the index,
|
||||
# must be updated when chunks are added or removed
|
||||
self.chunk_id_lock = asyncio.Lock()
|
||||
self.chunk_ids: list[Any] = []
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
|
|
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
try:
|
||||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
|
|
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
async with self.chunk_id_lock:
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
|
||||
|
||||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
new_chunk_by_index = {}
|
||||
for idx, chunk in self.chunk_by_index.items():
|
||||
# Shift all chunks after the removed chunk to the left
|
||||
if idx > index:
|
||||
new_chunk_by_index[idx - 1] = chunk
|
||||
else:
|
||||
new_chunk_by_index[idx] = chunk
|
||||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
|
|
@ -261,47 +289,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file data to kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
stored_data = await self.kvstore.get(key)
|
||||
return json.loads(stored_data) if stored_data else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in kvstore."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store data from kvstore."""
|
||||
assert self.kvstore is not None
|
||||
|
||||
keys_to_delete = [
|
||||
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
|
||||
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await self.kvstore.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete key {key}: {e}")
|
||||
continue
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
for chunk_id in chunk_ids:
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
|
|
@ -426,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
|
||||
def _delete_chunk():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
|
|
@ -506,140 +534,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to SQLite database."""
|
||||
|
||||
def _create_or_store():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
# Create a table to persist OpenAI vector store files.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
|
||||
store_id TEXT,
|
||||
file_id TEXT,
|
||||
contents TEXT,
|
||||
PRIMARY KEY (store_id, file_id)
|
||||
);
|
||||
""")
|
||||
connection.commit()
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_info)),
|
||||
)
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
|
||||
(store_id, file_id, json.dumps(file_contents)),
|
||||
)
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_create_or_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(metadata,) = row
|
||||
return metadata
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_data = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_data) if stored_data else {}
|
||||
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from SQLite database."""
|
||||
|
||||
def _load():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
(contents,) = row
|
||||
return contents
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
stored_contents = await asyncio.to_thread(_load)
|
||||
return json.loads(stored_contents) if stored_contents else []
|
||||
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in SQLite database."""
|
||||
|
||||
def _update():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
|
||||
(json.dumps(file_info), store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_update)
|
||||
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from SQLite database."""
|
||||
|
||||
def _delete():
|
||||
connection = _create_sqlite_connection(self.config.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
|
||||
(store_id, file_id),
|
||||
)
|
||||
connection.commit()
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
|
|
@ -655,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue