Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-16 16:22:37 +00:00)

Merge branch 'llamastack:main' into model_unregisteration_error_message

Commit 1180626a22: 103 changed files with 11265 additions and 704 deletions
@@ -797,7 +797,7 @@ class Agents(Protocol):
         self,
         response_id: str,
     ) -> OpenAIResponseObject:
-        """Retrieve an OpenAI response by its ID.
+        """Get a model response.

         :param response_id: The ID of the OpenAI response to retrieve.
         :returns: An OpenAIResponseObject.

@@ -826,7 +826,7 @@ class Agents(Protocol):
             ),
         ] = None,
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
-        """Create a new OpenAI response.
+        """Create a model response.

         :param input: Input message(s) to create the response.
         :param model: The underlying LLM used for completions.

@@ -846,7 +846,7 @@ class Agents(Protocol):
         model: str | None = None,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseObject:
-        """List all OpenAI responses.
+        """List all responses.

         :param after: The ID of the last response to return.
         :param limit: The number of responses to return.

@@ -869,7 +869,7 @@ class Agents(Protocol):
         limit: int | None = 20,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseInputItem:
-        """List input items for a given OpenAI response.
+        """List input items.

         :param response_id: The ID of the response to retrieve input items for.
         :param after: An item ID to list items after, used for pagination.

@@ -884,7 +884,7 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
     async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
-        """Delete an OpenAI response by its ID.
+        """Delete a response.

         :param response_id: The ID of the OpenAI response to delete.
         :returns: An OpenAIDeleteResponseObject
@@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
 @runtime_checkable
 @trace_protocol
 class Files(Protocol):
+    """Files
+
+    This API is used to upload documents that can be used with other Llama Stack APIs.
+    """
+
     # OpenAI Files API Endpoints
     @webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)

@@ -113,7 +118,8 @@ class Files(Protocol):
         purpose: Annotated[OpenAIFilePurpose, Form()],
         expires_after: Annotated[ExpiresAfter | None, Form()] = None,
     ) -> OpenAIFileObject:
-        """
+        """Upload file.
+
         Upload a file that can be used across various endpoints.

         The file upload should be a multipart form request with:

@@ -137,7 +143,8 @@ class Files(Protocol):
         order: Order | None = Order.desc,
         purpose: OpenAIFilePurpose | None = None,
     ) -> ListOpenAIFileResponse:
-        """
+        """List files.
+
         Returns a list of files that belong to the user's organization.

         :param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.

@@ -154,7 +161,8 @@ class Files(Protocol):
         self,
         file_id: str,
     ) -> OpenAIFileObject:
-        """
+        """Retrieve file.
+
         Returns information about a specific file.

         :param file_id: The ID of the file to use for this request.

@@ -168,8 +176,7 @@ class Files(Protocol):
         self,
         file_id: str,
     ) -> OpenAIFileDeleteResponse:
-        """
-        Delete a file.
+        """Delete file.

         :param file_id: The ID of the file to use for this request.
         :returns: An OpenAIFileDeleteResponse indicating successful deletion.

@@ -182,7 +189,8 @@ class Files(Protocol):
         self,
         file_id: str,
     ) -> Response:
-        """
+        """Retrieve file content.
+
         Returns the contents of the specified file.

         :param file_id: The ID of the file to use for this request.
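
Note on the route changes above: each endpoint is now registered twice, once under the legacy /openai/v1 prefix (marked deprecated=True) and once under the new un-prefixed path. A minimal sketch of that pattern, reusing the decorator shown in the diff (the handler name here is illustrative, not copied from the file):

    # Sketch only: both decorators register the same handler, so existing clients keep
    # working against /openai/v1/files/{file_id} while new clients use /files/{file_id}.
    @webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
    @webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: ...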
@@ -1053,7 +1053,9 @@ class InferenceProvider(Protocol):
         # for fill-in-the-middle type completion
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        """Generate an OpenAI-compatible completion for the given prompt using the specified model.
+        """Create completion.
+
+        Generate an OpenAI-compatible completion for the given prompt using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
         :param prompt: The prompt to generate a completion for.

@@ -1105,7 +1107,9 @@ class InferenceProvider(Protocol):
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
+        """Create chat completions.
+
+        Generate an OpenAI-compatible chat completion for the given messages using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
         :param messages: List of messages in the conversation.

@@ -1144,7 +1148,9 @@ class InferenceProvider(Protocol):
         dimensions: int | None = None,
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
-        """Generate OpenAI-compatible embeddings for the given input using the specified model.
+        """Create embeddings.
+
+        Generate OpenAI-compatible embeddings for the given input using the specified model.

         :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
         :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.

@@ -1157,7 +1163,9 @@ class InferenceProvider(Protocol):


 class Inference(InferenceProvider):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+    """Inference
+
+    Llama Stack Inference API for generating completions, chat completions, and embeddings.

     This API provides the raw interface to the underlying models. Two kinds of models are supported:
     - LLM models: these models generate "raw" and "chat" (conversational) completions.

@@ -1173,7 +1181,7 @@ class Inference(InferenceProvider):
         model: str | None = None,
         order: Order | None = Order.desc,
     ) -> ListOpenAIChatCompletionResponse:
-        """List all chat completions.
+        """List chat completions.

         :param after: The ID of the last chat completion to return.
         :param limit: The maximum number of chat completions to return.

@@ -1188,7 +1196,9 @@ class Inference(InferenceProvider):
     )
     @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
-        """Describe a chat completion by its ID.
+        """Get chat completion.
+
+        Describe a chat completion by its ID.

         :param completion_id: ID of the chat completion.
         :returns: A OpenAICompletionWithInputMessages.
@@ -58,9 +58,16 @@ class ListRoutesResponse(BaseModel):

 @runtime_checkable
 class Inspect(Protocol):
+    """Inspect
+
+    APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
+    """
+
     @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
     async def list_routes(self) -> ListRoutesResponse:
-        """List all available API routes with their methods and implementing providers.
+        """List routes.
+
+        List all available API routes with their methods and implementing providers.

         :returns: Response containing information about all available routes.
         """

@@ -68,7 +75,9 @@ class Inspect(Protocol):

     @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
     async def health(self) -> HealthInfo:
-        """Get the current health status of the service.
+        """Get health status.
+
+        Get the current health status of the service.

         :returns: Health information indicating if the service is operational.
         """

@@ -76,7 +85,9 @@ class Inspect(Protocol):

     @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
     async def version(self) -> VersionInfo:
-        """Get the version of the service.
+        """Get version.
+
+        Get the version of the service.

         :returns: Version information containing the service version number.
         """
@@ -124,7 +124,9 @@ class Models(Protocol):
         self,
         model_id: str,
     ) -> Model:
-        """Get a model by its identifier.
+        """Get model.
+
+        Get a model by its identifier.

         :param model_id: The identifier of the model to get.
         :returns: A Model.

@@ -140,7 +142,9 @@ class Models(Protocol):
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
     ) -> Model:
-        """Register a model.
+        """Register model.
+
+        Register a model.

         :param model_id: The identifier of the model to register.
         :param provider_model_id: The identifier of the model in the provider.

@@ -156,7 +160,9 @@ class Models(Protocol):
         self,
         model_id: str,
     ) -> None:
-        """Unregister a model.
+        """Unregister model.
+
+        Unregister a model.

         :param model_id: The identifier of the model to unregister.
         """
@@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
 @runtime_checkable
 @trace_protocol
 class Prompts(Protocol):
-    """Protocol for prompt management operations."""
+    """Prompts
+
+    Protocol for prompt management operations."""

     @webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
     async def list_prompts(self) -> ListPromptsResponse:

@@ -109,7 +111,9 @@ class Prompts(Protocol):
         self,
         prompt_id: str,
     ) -> ListPromptsResponse:
-        """List all versions of a specific prompt.
+        """List prompt versions.
+
+        List all versions of a specific prompt.

         :param prompt_id: The identifier of the prompt to list versions for.
         :returns: A ListPromptsResponse containing all versions of the prompt.

@@ -122,7 +126,9 @@ class Prompts(Protocol):
         prompt_id: str,
         version: int | None = None,
     ) -> Prompt:
-        """Get a prompt by its identifier and optional version.
+        """Get prompt.
+
+        Get a prompt by its identifier and optional version.

         :param prompt_id: The identifier of the prompt to get.
         :param version: The version of the prompt to get (defaults to latest).

@@ -136,7 +142,9 @@ class Prompts(Protocol):
         prompt: str,
         variables: list[str] | None = None,
     ) -> Prompt:
-        """Create a new prompt.
+        """Create prompt.
+
+        Create a new prompt.

         :param prompt: The prompt text content with variable placeholders.
         :param variables: List of variable names that can be used in the prompt template.

@@ -153,7 +161,9 @@ class Prompts(Protocol):
         variables: list[str] | None = None,
         set_as_default: bool = True,
     ) -> Prompt:
-        """Update an existing prompt (increments version).
+        """Update prompt.
+
+        Update an existing prompt (increments version).

         :param prompt_id: The identifier of the prompt to update.
         :param prompt: The updated prompt text content.

@@ -169,7 +179,9 @@ class Prompts(Protocol):
         self,
         prompt_id: str,
     ) -> None:
-        """Delete a prompt.
+        """Delete prompt.
+
+        Delete a prompt.

         :param prompt_id: The identifier of the prompt to delete.
         """

@@ -181,7 +193,9 @@ class Prompts(Protocol):
         prompt_id: str,
         version: int,
     ) -> Prompt:
-        """Set which version of a prompt should be the default in get_prompt (latest).
+        """Set prompt version.
+
+        Set which version of a prompt should be the default in get_prompt (latest).

         :param prompt_id: The identifier of the prompt.
         :param version: The version to set as default.
@@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):

 @runtime_checkable
 class Providers(Protocol):
-    """
+    """Providers
+
     Providers API for inspecting, listing, and modifying providers and their configurations.
     """

     @webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
     async def list_providers(self) -> ListProvidersResponse:
-        """List all available providers.
+        """List providers.
+
+        List all available providers.

         :returns: A ListProvidersResponse containing information about all providers.
         """

@@ -56,7 +59,9 @@ class Providers(Protocol):

     @webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
     async def inspect_provider(self, provider_id: str) -> ProviderInfo:
-        """Get detailed information about a specific provider.
+        """Get provider.
+
+        Get detailed information about a specific provider.

         :param provider_id: The ID of the provider to inspect.
         :returns: A ProviderInfo object containing the provider's details.
@@ -96,6 +96,11 @@ class ShieldStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class Safety(Protocol):
+    """Safety
+
+    OpenAI-compatible Moderations API.
+    """
+
     shield_store: ShieldStore

     @webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)

@@ -105,7 +110,9 @@ class Safety(Protocol):
         messages: list[Message],
         params: dict[str, Any],
     ) -> RunShieldResponse:
-        """Run a shield.
+        """Run shield.
+
+        Run a shield.

         :param shield_id: The identifier of the shield to run.
         :param messages: The messages to run the shield on.

@@ -117,7 +124,9 @@ class Safety(Protocol):
     @webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
     @webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
     async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
-        """Classifies if text and/or image inputs are potentially harmful.
+        """Create moderation.
+
+        Classifies if text and/or image inputs are potentially harmful.
         :param input: Input (or inputs) to classify.
         Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
         :param model: The content moderation model you would like to use.
@@ -75,39 +75,6 @@ class StackRun(Subcommand):
             help="Start the UI server",
         )

-    def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
-        """Resolve config file path and distribution name from args.config"""
-        from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
-
-        if not args.config:
-            return None, None
-
-        config_file = Path(args.config)
-        has_yaml_suffix = args.config.endswith(".yaml")
-        distro_name = None
-
-        if not config_file.exists() and not has_yaml_suffix:
-            # check if this is a distribution
-            config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml"
-            if config_file.exists():
-                distro_name = args.config
-
-        if not config_file.exists() and not has_yaml_suffix:
-            # check if it's a build config saved to ~/.llama dir
-            config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
-
-        if not config_file.exists():
-            self.parser.error(
-                f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
-            )
-
-        if not config_file.is_file():
-            self.parser.error(
-                f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
-            )
-
-        return config_file, distro_name
-
     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
         import yaml

@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             try:
                 models = await provider.list_models()
             except Exception as e:
-                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
+                logger.debug(f"Model refresh failed for provider {provider_id}: {e}")
                 continue

             self.listed_providers.add(provider_id)
@@ -245,3 +245,65 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             vector_store_id=vector_store_id,
             file_id=file_id,
         )
+
+    async def openai_create_vector_store_file_batch(
+        self,
+        vector_store_id: str,
+        file_ids: list[str],
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: Any | None = None,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_create_vector_store_file_batch(
+            vector_store_id=vector_store_id,
+            file_ids=file_ids,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+        )
+
+    async def openai_retrieve_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
+
+    async def openai_list_files_in_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        filter: str | None = None,
+        limit: int | None = 20,
+        order: str | None = "desc",
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+            after=after,
+            before=before,
+            filter=filter,
+            limit=limit,
+            order=order,
+        )
+
+    async def openai_cancel_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_cancel_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
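
Each of the new file-batch methods above follows the same guard-then-delegate shape: check the caller's access to the vector store, resolve the owning provider, then forward the call. A hedged sketch of that shape as a generic helper (the helper itself is hypothetical and not part of the diff):

    # Hypothetical helper illustrating the pattern used by the four methods above;
    # "read" or "update" is the action each method asserts before delegating.
    async def _route_file_batch_call(self, action: str, vector_store_id: str, method_name: str, **kwargs):
        await self.assert_action_allowed(action, "vector_db", vector_store_id)
        provider = await self.get_provider_impl(vector_store_id)
        return await getattr(provider, method_name)(vector_store_id=vector_store_id, **kwargs)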
@@ -9,7 +9,7 @@ from pathlib import Path

 from llama_stack.log import get_logger

-logger = get_logger(__name__, "tokenizer_utils")
+logger = get_logger(__name__, "models")


 def load_bpe_file(model_path: Path) -> dict[bytes, int]:
@@ -22,6 +22,7 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
         policy,
+        Api.telemetry in deps,
     )
     await impl.initialize()
     return impl
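
The new `Api.telemetry in deps` argument pairs with the `optional_api_dependencies=[Api.telemetry]` entry added to the agents provider registry later in this diff: the telemetry API may be absent, so the factory reduces its presence to a boolean. A sketch of the idea, with the factory name and the abbreviated constructor arguments being illustrative rather than the real signatures:

    # Illustrative sketch: deps contains Api.telemetry only when the run config
    # enables the telemetry API, so membership becomes the telemetry_enabled flag.
    async def build_agents_provider(config, deps, policy):
        telemetry_enabled = Api.telemetry in deps
        impl = MetaReferenceAgentsImpl(
            config,
            deps[Api.tool_runtime],
            deps[Api.tool_groups],
            policy,
            telemetry_enabled,  # other dependencies omitted for brevity
        )
        await impl.initialize()
        return impl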
@@ -110,6 +110,7 @@ class ChatAgent(ShieldRunnerMixin):
         persistence_store: KVStore,
         created_at: str,
         policy: list[AccessRule],
+        telemetry_enabled: bool = False,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config

@@ -120,6 +121,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
         self.created_at = created_at
+        self.telemetry_enabled = telemetry_enabled

         ShieldRunnerMixin.__init__(
             self,

@@ -188,28 +190,30 @@ class ChatAgent(ShieldRunnerMixin):

     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
         turn_id = str(uuid.uuid4())
-        span = tracing.get_current_span()
-        if span:
-            span.set_attribute("session_id", request.session_id)
-            span.set_attribute("agent_id", self.agent_id)
-            span.set_attribute("request", request.model_dump_json())
-            span.set_attribute("turn_id", turn_id)
-            if self.agent_config.name:
-                span.set_attribute("agent_name", self.agent_config.name)
+        if self.telemetry_enabled:
+            span = tracing.get_current_span()
+            if span is not None:
+                span.set_attribute("session_id", request.session_id)
+                span.set_attribute("agent_id", self.agent_id)
+                span.set_attribute("request", request.model_dump_json())
+                span.set_attribute("turn_id", turn_id)
+                if self.agent_config.name:
+                    span.set_attribute("agent_name", self.agent_config.name)

         await self._initialize_tools(request.toolgroups)
         async for chunk in self._run_turn(request, turn_id):
             yield chunk

     async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
-        span = tracing.get_current_span()
-        if span:
-            span.set_attribute("agent_id", self.agent_id)
-            span.set_attribute("session_id", request.session_id)
-            span.set_attribute("request", request.model_dump_json())
-            span.set_attribute("turn_id", request.turn_id)
-            if self.agent_config.name:
-                span.set_attribute("agent_name", self.agent_config.name)
+        if self.telemetry_enabled:
+            span = tracing.get_current_span()
+            if span is not None:
+                span.set_attribute("agent_id", self.agent_id)
+                span.set_attribute("session_id", request.session_id)
+                span.set_attribute("request", request.model_dump_json())
+                span.set_attribute("turn_id", request.turn_id)
+                if self.agent_config.name:
+                    span.set_attribute("agent_name", self.agent_config.name)

         await self._initialize_tools()
         async for chunk in self._run_turn(request):

@@ -395,9 +399,12 @@ class ChatAgent(ShieldRunnerMixin):
         touchpoint: str,
     ) -> AsyncGenerator:
         async with tracing.span("run_shields") as span:
-            span.set_attribute("input", [m.model_dump_json() for m in messages])
+            if self.telemetry_enabled and span is not None:
+                span.set_attribute("input", [m.model_dump_json() for m in messages])
+                if len(shields) == 0:
+                    span.set_attribute("output", "no shields")
+
             if len(shields) == 0:
-                span.set_attribute("output", "no shields")
                 return

             step_id = str(uuid.uuid4())

@@ -430,7 +437,8 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     )
                 )
-                span.set_attribute("output", e.violation.model_dump_json())
+                if self.telemetry_enabled and span is not None:
+                    span.set_attribute("output", e.violation.model_dump_json())

                 yield CompletionMessage(
                     content=str(e),

@@ -453,7 +461,8 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
-            span.set_attribute("output", "no violations")
+            if self.telemetry_enabled and span is not None:
+                span.set_attribute("output", "no violations")

     async def _run(
         self,

@@ -518,8 +527,9 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason: StopReason | None = None

         async with tracing.span("inference") as span:
-            if self.agent_config.name:
-                span.set_attribute("agent_name", self.agent_config.name)
+            if self.telemetry_enabled and span is not None:
+                if self.agent_config.name:
+                    span.set_attribute("agent_name", self.agent_config.name)

             def _serialize_nested(value):
                 """Recursively serialize nested Pydantic models to dicts."""

@@ -637,18 +647,19 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     raise ValueError(f"Unexpected delta type {type(delta)}")

-            span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
-            span.set_attribute(
-                "input",
-                json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
-            )
-            output_attr = json.dumps(
-                {
-                    "content": content,
-                    "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
-                }
-            )
-            span.set_attribute("output", output_attr)
+            if self.telemetry_enabled and span is not None:
+                span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
+                span.set_attribute(
+                    "input",
+                    json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
+                )
+                output_attr = json.dumps(
+                    {
+                        "content": content,
+                        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                    }
+                )
+                span.set_attribute("output", output_attr)

             n_iter += 1
             await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)

@@ -756,7 +767,9 @@ class ChatAgent(ShieldRunnerMixin):
                 {
                     "tool_name": tool_call.tool_name,
                     "input": message.model_dump_json(),
-                },
+                }
+                if self.telemetry_enabled
+                else {},
             ) as span:
                 tool_execution_start_time = datetime.now(UTC).isoformat()
                 tool_result = await self.execute_tool_call_maybe(

@@ -771,7 +784,8 @@ class ChatAgent(ShieldRunnerMixin):
                     call_id=tool_call.call_id,
                     content=tool_result.content,
                 )
-                span.set_attribute("output", result_message.model_dump_json())
+                if self.telemetry_enabled and span is not None:
+                    span.set_attribute("output", result_message.model_dump_json())

                 # Store tool execution step
                 tool_execution_step = ToolExecutionStep(
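
Every telemetry change in this file applies the same guard: span attributes are written only when `self.telemetry_enabled` is set and a span is actually active. A hypothetical helper (not in the diff) showing how that repeated check could be expressed once:

    # Hypothetical helper sketch: centralizes the "telemetry on and span present" guard
    # that the diff above inlines at each call site.
    def _maybe_set_span_attributes(self, span, attributes: dict) -> None:
        if not self.telemetry_enabled or span is None:
            return
        for key, value in attributes.items():
            span.set_attribute(key, value)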
@@ -64,6 +64,7 @@ class MetaReferenceAgentsImpl(Agents):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         policy: list[AccessRule],
+        telemetry_enabled: bool = False,
     ):
         self.config = config
         self.inference_api = inference_api

@@ -71,6 +72,7 @@ class MetaReferenceAgentsImpl(Agents):
         self.safety_api = safety_api
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.telemetry_enabled = telemetry_enabled

         self.in_memory_store = InmemoryKVStoreImpl()
         self.openai_responses_impl: OpenAIResponsesImpl | None = None

@@ -135,6 +137,7 @@ class MetaReferenceAgentsImpl(Agents):
             ),
             created_at=agent_info.created_at,
             policy=self.policy,
+            telemetry_enabled=self.telemetry_enabled,
         )

     async def create_agent_session(
@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
+        # mapping for annotations
+        self.citation_files: dict[str, str] = {}

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages

@@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
         # Text is the default response format for chat completion so don't need to pass it
         # (some providers don't support non-empty response_format when tools are present)
         response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+        logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
         completion_result = await self.inference_api.openai_chat_completion(
             model=self.ctx.model,
             messages=messages,

@@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
         # Handle choices with no tool calls
         for choice in current_response.choices:
             if not (choice.message.tool_calls and self.ctx.response_tools):
-                output_messages.append(await convert_chat_choice_to_response_message(choice))
+                output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))

         # Execute tool calls and coordinate results
         async for stream_event in self._coordinate_tool_execution(

@@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:

         for choice in current_response.choices:
             next_turn_messages.append(choice.message)
+            logger.debug(f"Choice message content: {choice.message.content}")
+            logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")

             if choice.message.tool_calls and self.ctx.response_tools:
                 for tool_call in choice.message.tool_calls:

@@ -470,6 +475,8 @@ class StreamingResponseOrchestrator:
             tool_call_log = result.final_output_message
             tool_response_message = result.final_input_message
             self.sequence_number = result.sequence_number
+            if result.citation_files:
+                self.citation_files.update(result.citation_files)

             if tool_call_log:
                 output_messages.append(tool_call_log)
@@ -94,7 +94,10 @@ class ToolExecutor:

         # Yield the final result
         yield ToolExecutionResult(
-            sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+            sequence_number=sequence_number,
+            final_output_message=output_message,
+            final_input_message=input_message,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )

     async def _execute_knowledge_search_via_vector_store(

@@ -129,8 +132,6 @@ class ToolExecutor:
         for results in all_results:
             search_results.extend(results)

-        # Convert search results to tool result format matching memory.py
-        # Format the results as interleaved content similar to memory.py
         content_items = []
         content_items.append(
             TextContentItem(

@@ -138,27 +139,58 @@ class ToolExecutor:
             )
         )

+        unique_files = set()
         for i, result_item in enumerate(search_results):
             chunk_text = result_item.content[0].text if result_item.content else ""
-            metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+            # Get file_id from attributes if result_item.file_id is empty
+            file_id = result_item.file_id or (
+                result_item.attributes.get("document_id") if result_item.attributes else None
+            )
+            metadata_text = f"document_id: {file_id}, score: {result_item.score}"
             if result_item.attributes:
                 metadata_text += f", attributes: {result_item.attributes}"
-            text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+
+            text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
             content_items.append(TextContentItem(text=text_content))
+            unique_files.add(file_id)

         content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))

+        citation_instruction = ""
+        if unique_files:
+            citation_instruction = (
+                " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+                "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )
+
         content_items.append(
             TextContentItem(
-                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
             )
         )

+        # handling missing attributes for old versions
+        citation_files = {}
+        for result in search_results:
+            file_id = result.file_id
+            if not file_id and result.attributes:
+                file_id = result.attributes.get("document_id")
+
+            filename = result.filename
+            if not filename and result.attributes:
+                filename = result.attributes.get("filename")
+            if not filename:
+                filename = "unknown"
+
+            citation_files[file_id] = filename
+
         return ToolInvocationResult(
             content=content_items,
             metadata={
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
+                "citation_files": citation_files,
             },
         )

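
After this change each retrieved chunk carries an explicit "(cite as <|file-id|>)" hint and the file-id to filename map travels back in the tool result metadata. A short illustration of the string the loop builds, using a made-up score and chunk text (the file id is the example id already used in the citation instruction above):

    # Illustrative values only; the real loop derives file_id from the search result.
    file_id = "file-Cn3MSNn72ENTiiq11Qda4A"
    score = 0.87
    chunk_text = "Llama Stack exposes an OpenAI-compatible Responses API."
    metadata_text = f"document_id: {file_id}, score: {score}"
    text_content = f"[1] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"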
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
+    citation_files: dict[str, str] | None = None


 @dataclass
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import re
 import uuid

 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,

@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
 )


-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
     if isinstance(choice.message.content, str):

@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )

+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
         status="completed",
         role="assistant",
     )

@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)


+def _extract_citations_from_text(
+    text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+    """Extract citation markers from text and create annotations
+
+    Args:
+        text: The text containing citation markers like [file-Cn3MSNn72ENTiiq11Qda4A]
+        citation_files: Dictionary mapping file_id to filename
+
+    Returns:
+        Tuple of (annotations_list, clean_text_without_markers)
+    """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+    annotations = []
+    parts = []
+    total_len = 0
+    last_end = 0
+
+    for m in file_id_regex.finditer(text):
+        # segment before the marker
+        prefix = text[last_end : m.start()]
+
+        # drop one space if it exists (since marker is at sentence end)
+        if prefix.endswith(" "):
+            prefix = prefix[:-1]
+
+        parts.append(prefix)
+        total_len += len(prefix)
+
+        fid = m.group(1)
+        if fid in citation_files:
+            annotations.append(
+                OpenAIResponseAnnotationFileCitation(
+                    file_id=fid,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points to punctuation
+                )
+            )
+
+        last_end = m.end()
+
+    parts.append(text[last_end:])
+    cleaned_text = "".join(parts)
+    return annotations, cleaned_text
+
+
 def is_function_tool_call(
     tool_call: OpenAIChatCompletionToolCall,
     tools: list[OpenAIResponseInputTool],
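
A short usage sketch of the new helper above, assuming one known file id (the filename is made up). The marker is stripped from the text and the returned annotation index points at the trailing punctuation in the cleaned text:

    # Usage sketch for _extract_citations_from_text; values are illustrative.
    citation_files = {"file-Cn3MSNn72ENTiiq11Qda4A": "llama_stack_overview.md"}
    text = "This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>."

    annotations, clean_text = _extract_citations_from_text(text, citation_files)
    # clean_text == "This is a fact."
    # annotations[0].file_id == "file-Cn3MSNn72ENTiiq11Qda4A"
    # annotations[0].index == 14  (the position of the period in the cleaned text)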
@@ -331,5 +331,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti

         return ToolInvocationResult(
             content=result.content or [],
-            metadata=result.metadata,
+            metadata={
+                **(result.metadata or {}),
+                "citation_files": getattr(result, "citation_files", None),
+            },
         )
@@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):

 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
-        self.kvstore: KVStore | None = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
     """

     def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
-        self.kvstore: KVStore | None = None

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -36,6 +36,9 @@ def available_providers() -> list[ProviderSpec]:
                 Api.tool_runtime,
                 Api.tool_groups,
             ],
+            optional_api_dependencies=[
+                Api.telemetry,
+            ],
             description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
         ),
     ]
@@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
     ProviderSpec,
     RemoteProviderSpec,
 )
+from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS


 def available_providers() -> list[ProviderSpec]:

@@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
         InlineProviderSpec(
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
-            pip_packages=[
-                "chardet",
-                "pypdf",
+            pip_packages=DEFAULT_VECTOR_IO_DEPS
+            + [
                 "tqdm",
                 "numpy",
                 "scikit-learn",
@@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
 )

+# Common dependencies for all vector IO providers that support document processing
+DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
+

 def available_providers() -> list[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::meta-reference",
-            pip_packages=["faiss-cpu"],
+            pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",

@@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::faiss",
-            pip_packages=["faiss-cpu"],
+            pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -82,7 +85,7 @@ more details about Faiss in general.
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::sqlite-vec",
-            pip_packages=["sqlite-vec"],
+            pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::sqlite_vec",
-            pip_packages=["sqlite-vec"],
+            pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",

@@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
             api=Api.vector_io,
             adapter_type="chromadb",
             provider_type="remote::chromadb",
-            pip_packages=["chromadb-client"],
+            pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.chroma",
             config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::chromadb",
-            pip_packages=["chromadb"],
+            pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.chroma",
             config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
             api=Api.vector_io,
             adapter_type="pgvector",
             provider_type="remote::pgvector",
-            pip_packages=["psycopg2-binary"],
+            pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.pgvector",
             config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
             api=Api.vector_io,
             adapter_type="weaviate",
             provider_type="remote::weaviate",
-            pip_packages=["weaviate-client>=4.16.5"],
+            pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.weaviate",
             config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
             provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",

@@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::qdrant",
-            pip_packages=["qdrant-client"],
+            pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.qdrant",
             config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
             api=Api.vector_io,
             adapter_type="qdrant",
             provider_type="remote::qdrant",
-            pip_packages=["qdrant-client"],
+            pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.qdrant",
             config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
             api=Api.vector_io,
             adapter_type="milvus",
             provider_type="remote::milvus",
-            pip_packages=["pymilvus>=2.4.10"],
+            pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.remote.vector_io.milvus",
             config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
             api_dependencies=[Api.inference],

@@ -813,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
         InlineProviderSpec(
             api=Api.vector_io,
             provider_type="inline::milvus",
-            pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
+            pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
             module="llama_stack.providers.inline.vector_io.milvus",
             config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
             api_dependencies=[Api.inference],
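
The new DEFAULT_VECTOR_IO_DEPS constant keeps the shared document-processing packages in one place; every provider spec above appends it to its backend-specific list. A one-line illustration using names from the diff:

    # Illustration: composing a provider's pip_packages from the shared constant.
    DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
    pip_packages = ["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS  # ["faiss-cpu", "chardet", "pypdf"]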
@@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
             ).serving_endpoints.list()  # TODO: this is not async
         ]

-    async def should_refresh_models(self) -> bool:
-        return False
-
     async def openai_completion(
         self,
         model: str,
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 from typing import Any

-from llama_stack.apis.inference.inference import OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -56,3 +56,13 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
         suffix: str | None = None,
     ) -> OpenAICompletion:
         raise NotImplementedError()
+
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
@@ -6,8 +6,6 @@

 from typing import Any

-from pydantic import Field
-
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

 DEFAULT_OLLAMA_URL = "http://localhost:11434"

@@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"

 class OllamaImplConfig(RemoteInferenceProviderConfig):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(
-        default=False,
-        description="Whether to refresh models periodically",
-    )

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
@@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin):
                 f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
             )

-    async def should_refresh_models(self) -> bool:
-        return self.config.refresh_models
-
     async def health(self) -> HealthResponse:
         """
         Performs a health check by verifying connectivity to the Ollama server.
@@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
     from .runpod import RunpodInferenceAdapter

     assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
-    impl = RunpodInferenceAdapter(config)
+    impl = RunpodInferenceAdapter(config=config)
     await impl.initialize()
     return impl
@@ -4,69 +4,86 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from typing import Any

-from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.inference import OpenAIEmbeddingsResponse
-
-# from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
+from llama_stack.apis.inference import (
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import RunpodImplConfig

-# https://docs.runpod.io/serverless/vllm/overview#compatible-models
-# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
-RUNPOD_SUPPORTED_MODELS = {
-    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
-    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
-    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
-    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
-    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
-    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
-    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
-    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
-    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
-}
-
-SAFETY_MODELS_ENTRIES = []
+class RunpodInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for RunPod's OpenAI-compatible API endpoints.
+    Supports VLLM for serverless endpoint self-hosted or public endpoints.
+    Can work with any runpod endpoints that support OpenAI-compatible API
+    """

-# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(provider_model_id, model_descriptor)
-    for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+    config: RunpodImplConfig

+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        return self.config.api_token

-class RunpodInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: RunpodImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
-        self.config = config
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return self.config.url

-    def _get_params(self, request: ChatCompletionRequest) -> dict:
-        return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request),
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
-    async def openai_embeddings(
+    async def openai_chat_completion(
         self,
         model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
         user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    ):
+        """Override to add RunPod-specific stream_options requirement."""
+        if stream and not stream_options:
+            stream_options = {"include_usage": True}
+
+        return await super().openai_chat_completion(
+            model=model,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
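
The rewritten adapter adds exactly one behavior on top of OpenAIMixin before delegating: when a streaming request arrives without stream_options, it injects {"include_usage": True} so RunPod returns token usage in the stream. The guard in isolation, as a hypothetical standalone helper (the adapter inlines this check):

    # Hypothetical helper sketch of the RunPod-specific default shown above.
    def _with_runpod_stream_defaults(stream: bool | None, stream_options: dict | None) -> dict | None:
        if stream and not stream_options:
            return {"include_usage": True}
        return stream_options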
@@ -63,12 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         # Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
         return [m.id for m in await self._get_client().models.list()]

-    async def should_refresh_models(self) -> bool:
-        return True
-
-    async def check_model_availability(self, model):
-        return model in self._model_cache
-
     async def openai_embeddings(
         self,
         model: str,
@@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
        default=True,
        description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
    )
    refresh_models: bool = Field(
        default=False,
        description="Whether to refresh models periodically",
    )

    @field_validator("tls_verify")
    @classmethod
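A standalone sketch of the Field/field_validator pattern the config above relies on; the class name and the validator body shown here are illustrative assumptions, not the adapter's actual validation logic.

from pathlib import Path

from pydantic import BaseModel, Field, field_validator


class ExampleAdapterConfig(BaseModel):
    tls_verify: bool | str = Field(
        default=True,
        description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
    )
    refresh_models: bool = Field(
        default=False,
        description="Whether to refresh models periodically",
    )

    @field_validator("tls_verify")
    @classmethod
    def validate_tls_verify(cls, v: bool | str) -> bool | str:
        # Illustrative check: treat a string value as a CA bundle path that must exist.
        if isinstance(v, str) and not Path(v).is_file():
            raise ValueError(f"CA certificate file not found: {v}")
        return v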
@@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
    OpenAIResponseFormatParam,
    ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,

@@ -54,25 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
                "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
            )

    async def should_refresh_models(self) -> bool:
        # Strictly respecting the refresh_models directive
        return self.config.refresh_models

    async def list_models(self) -> list[Model] | None:
        models = []
        async for m in self.client.models.list():
            model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
            models.append(
                Model(
                    identifier=m.id,
                    provider_resource_id=m.id,
                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
                    metadata={},
                    model_type=model_type,
                )
            )
        return models

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the remote vLLM server.
@@ -140,14 +140,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        inference_api: Api.inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
        self.config = config
        self.inference_api = inference_api
        self.client = None
        self.cache = {}
        self.kvstore: KVStore | None = None
        self.vector_db_store = None
        self.files_api = files_api

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -309,14 +309,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        inference_api: Inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.cache = {}
        self.client = None
        self.inference_api = inference_api
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.vector_db_store = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.metadata_collection_name = "openai_vector_stores_metadata"

    async def initialize(self) -> None:
@@ -345,14 +345,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
        inference_api: Api.inference,
        files_api: Files | None = None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.conn = None
        self.cache = {}
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.vector_db_store = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.metadata_collection_name = "openai_vector_stores_metadata"

    async def initialize(self) -> None:
@@ -27,7 +27,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    ChunkForDeletion,

@@ -162,14 +162,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        inference_api: Api.inference,
        files_api: Files | None = None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.client: AsyncQdrantClient = None
        self.cache = {}
        self.inference_api = inference_api
        self.files_api = files_api
        self.vector_db_store = None
        self.kvstore: KVStore | None = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self._qdrant_lock = asyncio.Lock()

    async def initialize(self) -> None:
@@ -284,14 +284,12 @@ class WeaviateVectorIOAdapter(
        inference_api: Api.inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.client_cache = {}
        self.cache = {}
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.vector_db_store = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.metadata_collection_name = "openai_vector_stores_metadata"

    def _get_client(self) -> weaviate.WeaviateClient:
@@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel):
        default=None,
        description="List of models that should be registered with the model registry. If None, all models are allowed.",
    )
    refresh_models: bool = Field(
        default=False,
        description="Whether to refresh models periodically from the provider",
    )


# TODO: this class is more confusing than useful right now. We need to make it
@@ -132,7 +132,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):

        :return: An iterable of model IDs or None if not implemented
        """
        return [m.id async for m in self.client.models.list()]
        client = self.client
        async with client:
            model_ids = [m.id async for m in client.models.list()]
        return model_ids

    async def initialize(self) -> None:
        """

@@ -481,7 +484,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
        return model in self._model_cache

    async def should_refresh_models(self) -> bool:
        return False
        return self.config.refresh_models

    #
    # The model_dump implementations are to avoid serializing the extra fields,
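The list_models change above wraps the client in an async context manager so the underlying HTTP resources are released even if iteration fails part-way through. A minimal sketch of the same pattern (endpoint and key are placeholders):

import asyncio

from openai import AsyncOpenAI


async def list_model_ids(base_url: str, api_key: str) -> list[str]:
    async with AsyncOpenAI(base_url=base_url, api_key=api_key) as client:
        # async iteration pages through /v1/models; the context manager
        # closes the client when the block exits.
        return [m.id async for m in client.models.list()]


# asyncio.run(list_model_ids("http://localhost:8000/v1", "placeholder"))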
@@ -12,6 +12,8 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
@@ -50,12 +52,16 @@ logger = get_logger(name=__name__, category="providers::utils")

# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
FILE_BATCH_CLEANUP_INTERVAL_SECONDS = 24 * 60 * 60  # 1 day in seconds
MAX_CONCURRENT_FILES_PER_BATCH = 3  # Maximum concurrent file processing within a batch
FILE_BATCH_CHUNK_SIZE = 10  # Process files in chunks of this size

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX = f"openai_vector_stores_file_batches:{VERSION}::"


class OpenAIVectorStoreMixin(ABC):
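The prefix constants above are used later in this diff for range scans over the KV store: every record is written under "<prefix><id>", and the scan upper bound is the prefix followed by "\xff", which (assuming lexicographic key comparison) covers exactly the keys sharing that prefix. A small sketch:

VERSION = "v3"
FILE_BATCHES_PREFIX = f"openai_vector_stores_file_batches:{VERSION}::"


def batch_key(batch_id: str) -> str:
    # e.g. "openai_vector_stores_file_batches:v3::batch_123"
    return f"{FILE_BATCHES_PREFIX}{batch_id}"


def batch_scan_bounds() -> tuple[str, str]:
    # Passed to values_in_range(start_key, end_key) to fetch every stored batch.
    return FILE_BATCHES_PREFIX, f"{FILE_BATCHES_PREFIX}\xff"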
@@ -65,11 +71,15 @@ class OpenAIVectorStoreMixin(ABC):
    an openai_vector_stores in-memory cache.
    """

    # These should be provided by the implementing class
    openai_vector_stores: dict[str, dict[str, Any]]
    files_api: Files | None
    # KV store for persisting OpenAI vector store metadata
    kvstore: KVStore | None
    # Implementing classes should call super().__init__() in their __init__ method
    # to properly initialize the mixin attributes.
    def __init__(self, files_api: Files | None = None, kvstore: KVStore | None = None):
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.openai_file_batches: dict[str, dict[str, Any]] = {}
        self.files_api = files_api
        self.kvstore = kvstore
        self._last_file_batch_cleanup_time = 0
        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Save vector store metadata to persistent storage."""
@@ -159,9 +169,129 @@ class OpenAIVectorStoreMixin(ABC):
        for idx in range(len(raw_items)):
            await self.kvstore.delete(f"{contents_prefix}{idx}")

    async def _save_openai_vector_store_file_batch(self, batch_id: str, batch_info: dict[str, Any]) -> None:
        """Save file batch metadata to persistent storage."""
        assert self.kvstore
        key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
        await self.kvstore.set(key=key, value=json.dumps(batch_info))
        # update in-memory cache
        self.openai_file_batches[batch_id] = batch_info

    async def _load_openai_vector_store_file_batches(self) -> dict[str, dict[str, Any]]:
        """Load all file batch metadata from persistent storage."""
        assert self.kvstore
        start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
        end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
        stored_data = await self.kvstore.values_in_range(start_key, end_key)

        batches: dict[str, dict[str, Any]] = {}
        for item in stored_data:
            info = json.loads(item)
            batches[info["id"]] = info
        return batches

    async def _delete_openai_vector_store_file_batch(self, batch_id: str) -> None:
        """Delete file batch metadata from persistent storage and in-memory cache."""
        assert self.kvstore
        key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
        await self.kvstore.delete(key)
        # remove from in-memory cache
        self.openai_file_batches.pop(batch_id, None)

    async def _cleanup_expired_file_batches(self) -> None:
        """Clean up expired file batches from persistent storage."""
        assert self.kvstore
        start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
        end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
        stored_data = await self.kvstore.values_in_range(start_key, end_key)

        current_time = int(time.time())
        expired_count = 0

        for item in stored_data:
            info = json.loads(item)
            expires_at = info.get("expires_at")
            if expires_at and current_time > expires_at:
                logger.info(f"Cleaning up expired file batch: {info['id']}")
                await self.kvstore.delete(f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{info['id']}")
                # Remove from in-memory cache if present
                self.openai_file_batches.pop(info["id"], None)
                expired_count += 1

        if expired_count > 0:
            logger.info(f"Cleaned up {expired_count} expired file batches")

    async def _get_completed_files_in_batch(self, vector_store_id: str, file_ids: list[str]) -> set[str]:
        """Determine which files in a batch are actually completed by checking vector store file_ids."""
        if vector_store_id not in self.openai_vector_stores:
            return set()

        store_info = self.openai_vector_stores[vector_store_id]
        completed_files = set(file_ids) & set(store_info["file_ids"])
        return completed_files

    async def _analyze_batch_completion_on_resume(self, batch_id: str, batch_info: dict[str, Any]) -> list[str]:
        """Analyze batch completion status and return remaining files to process.

        Returns:
            List of file IDs that still need processing. Empty list if batch is complete.
        """
        vector_store_id = batch_info["vector_store_id"]
        all_file_ids = batch_info["file_ids"]

        # Find files that are actually completed
        completed_files = await self._get_completed_files_in_batch(vector_store_id, all_file_ids)
        remaining_files = [file_id for file_id in all_file_ids if file_id not in completed_files]

        completed_count = len(completed_files)
        total_count = len(all_file_ids)
        remaining_count = len(remaining_files)

        # Update file counts to reflect actual state
        batch_info["file_counts"] = {
            "completed": completed_count,
            "failed": 0,  # We don't track failed files during resume - they'll be retried
            "in_progress": remaining_count,
            "cancelled": 0,
            "total": total_count,
        }

        # If all files are already completed, mark batch as completed
        if remaining_count == 0:
            batch_info["status"] = "completed"
            logger.info(f"Batch {batch_id} is already fully completed, updating status")

        # Save updated batch info
        await self._save_openai_vector_store_file_batch(batch_id, batch_info)

        return remaining_files

    async def _resume_incomplete_batches(self) -> None:
        """Resume processing of incomplete file batches after server restart."""
        for batch_id, batch_info in self.openai_file_batches.items():
            if batch_info["status"] == "in_progress":
                logger.info(f"Analyzing incomplete file batch: {batch_id}")

                remaining_files = await self._analyze_batch_completion_on_resume(batch_id, batch_info)

                # Check if batch is now completed after analysis
                if batch_info["status"] == "completed":
                    continue

                if remaining_files:
                    logger.info(f"Resuming batch {batch_id} with {len(remaining_files)} remaining files")
                    # Restart the background processing task with only remaining files
                    task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info, remaining_files))
                    self._file_batch_tasks[batch_id] = task

    async def initialize_openai_vector_stores(self) -> None:
        """Load existing OpenAI vector stores into the in-memory cache."""
        """Load existing OpenAI vector stores and file batches into the in-memory cache."""
        self.openai_vector_stores = await self._load_openai_vector_stores()
        self.openai_file_batches = await self._load_openai_vector_store_file_batches()
        self._file_batch_tasks = {}
        # TODO: Resume only works for single worker deployment. Jobs with multiple workers will need to be handled differently.
        await self._resume_incomplete_batches()
        self._last_file_batch_cleanup_time = 0

    @abstractmethod
    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
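A worked example of the resume bookkeeping above, with hypothetical IDs: files already present in the store's file_ids are counted as completed, and only the rest are reprocessed.

all_file_ids = ["file-a", "file-b", "file-c", "file-d"]
store_file_ids = {"file-a", "file-c"}  # hypothetical: what the vector store already holds

completed_files = set(all_file_ids) & store_file_ids
remaining_files = [f for f in all_file_ids if f not in completed_files]

assert completed_files == {"file-a", "file-c"}
assert remaining_files == ["file-b", "file-d"]
# file_counts after resume: completed=2, in_progress=2, failed=0, cancelled=0, total=4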
@@ -457,7 +587,7 @@ class OpenAIVectorStoreMixin(ABC):
                content = self._chunk_to_vector_store_content(chunk)

                response_data_item = VectorStoreSearchResponse(
                    file_id=chunk.metadata.get("file_id", ""),
                    file_id=chunk.metadata.get("document_id", ""),
                    filename=chunk.metadata.get("filename", ""),
                    score=score,
                    attributes=chunk.metadata,
@@ -570,6 +700,14 @@ class OpenAIVectorStoreMixin(ABC):
        if vector_store_id not in self.openai_vector_stores:
            raise VectorStoreNotFoundError(vector_store_id)

        # Check if file is already attached to this vector store
        store_info = self.openai_vector_stores[vector_store_id]
        if file_id in store_info["file_ids"]:
            logger.warning(f"File {file_id} is already attached to vector store {vector_store_id}, skipping")
            # Return existing file object
            file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
            return VectorStoreFileObject(**file_info)

        attributes = attributes or {}
        chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
        created_at = int(time.time())
@@ -608,14 +746,16 @@ class OpenAIVectorStoreMixin(ABC):

            content = content_from_data_and_mime_type(content_response.body, mime_type)

            chunk_attributes = attributes.copy()
            chunk_attributes["filename"] = file_response.filename

            chunks = make_overlapped_chunks(
                file_id,
                content,
                max_chunk_size_tokens,
                chunk_overlap_tokens,
                attributes,
                chunk_attributes,
            )

            if not chunks:
                vector_store_file_object.status = "failed"
                vector_store_file_object.last_error = VectorStoreFileLastError(
@@ -828,7 +968,230 @@ class OpenAIVectorStoreMixin(ABC):
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileBatchObject:
        """Create a vector store file batch."""
        raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet")
        if vector_store_id not in self.openai_vector_stores:
            raise VectorStoreNotFoundError(vector_store_id)

        chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()

        created_at = int(time.time())
        batch_id = f"batch_{uuid.uuid4()}"
        # File batches expire after 7 days
        expires_at = created_at + (7 * 24 * 60 * 60)

        # Initialize batch file counts - all files start as in_progress
        file_counts = VectorStoreFileCounts(
            completed=0,
            cancelled=0,
            failed=0,
            in_progress=len(file_ids),
            total=len(file_ids),
        )

        # Create batch object immediately with in_progress status
        batch_object = VectorStoreFileBatchObject(
            id=batch_id,
            created_at=created_at,
            vector_store_id=vector_store_id,
            status="in_progress",
            file_counts=file_counts,
        )

        batch_info = {
            **batch_object.model_dump(),
            "file_ids": file_ids,
            "attributes": attributes,
            "chunking_strategy": chunking_strategy.model_dump(),
            "expires_at": expires_at,
        }
        await self._save_openai_vector_store_file_batch(batch_id, batch_info)

        # Start background processing of files
        task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
        self._file_batch_tasks[batch_id] = task

        # Run cleanup if needed (throttled to once every 1 day)
        current_time = int(time.time())
        if current_time - self._last_file_batch_cleanup_time >= FILE_BATCH_CLEANUP_INTERVAL_SECONDS:
            logger.info("Running throttled cleanup of expired file batches")
            asyncio.create_task(self._cleanup_expired_file_batches())
            self._last_file_batch_cleanup_time = current_time

        return batch_object

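A hedged usage sketch of the batch lifecycle built only from methods defined in this mixin; the full create signature is not shown in this hunk, so the keyword arguments and the adapter object here are assumptions.

import asyncio


async def run_batch(adapter, vector_store_id: str, file_ids: list[str]) -> None:
    # `adapter` is any implementation built on OpenAIVectorStoreMixin (hypothetical).
    batch = await adapter.openai_create_vector_store_file_batch(
        vector_store_id=vector_store_id,
        file_ids=file_ids,
    )
    # Creation returns immediately with status "in_progress"; poll until it settles.
    while batch.status == "in_progress":
        await asyncio.sleep(1)
        batch = await adapter.openai_retrieve_vector_store_file_batch(
            batch_id=batch.id,
            vector_store_id=vector_store_id,
        )
    print(batch.status, batch.file_counts)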
    async def _process_files_with_concurrency(
        self,
        file_ids: list[str],
        vector_store_id: str,
        attributes: dict[str, Any],
        chunking_strategy_obj: Any,
        batch_id: str,
        batch_info: dict[str, Any],
    ) -> None:
        """Process files with controlled concurrency and chunking."""
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)

        async def process_single_file(file_id: str) -> tuple[str, bool]:
            """Process a single file with concurrency control."""
            async with semaphore:
                try:
                    vector_store_file_object = await self.openai_attach_file_to_vector_store(
                        vector_store_id=vector_store_id,
                        file_id=file_id,
                        attributes=attributes,
                        chunking_strategy=chunking_strategy_obj,
                    )
                    return file_id, vector_store_file_object.status == "completed"
                except Exception as e:
                    logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
                    return file_id, False

        # Process files in chunks to avoid creating too many tasks at once
        total_files = len(file_ids)
        for chunk_start in range(0, total_files, FILE_BATCH_CHUNK_SIZE):
            chunk_end = min(chunk_start + FILE_BATCH_CHUNK_SIZE, total_files)
            chunk = file_ids[chunk_start:chunk_end]

            chunk_num = chunk_start // FILE_BATCH_CHUNK_SIZE + 1
            total_chunks = (total_files + FILE_BATCH_CHUNK_SIZE - 1) // FILE_BATCH_CHUNK_SIZE
            logger.info(
                f"Processing chunk {chunk_num} of {total_chunks} ({len(chunk)} files, {chunk_start + 1}-{chunk_end} of {total_files} total files)"
            )

            async with asyncio.TaskGroup() as tg:
                chunk_tasks = [tg.create_task(process_single_file(file_id)) for file_id in chunk]

            chunk_results = [task.result() for task in chunk_tasks]

            # Update counts after each chunk for progressive feedback
            for _, success in chunk_results:
                self._update_file_counts(batch_info, success=success)

            # Save progress after each chunk
            await self._save_openai_vector_store_file_batch(batch_id, batch_info)

    def _update_file_counts(self, batch_info: dict[str, Any], success: bool) -> None:
        """Update file counts based on processing result."""
        if success:
            batch_info["file_counts"]["completed"] += 1
        else:
            batch_info["file_counts"]["failed"] += 1
        batch_info["file_counts"]["in_progress"] -= 1

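The method above combines a semaphore (bounding in-flight work) with fixed-size chunks (so progress can be persisted between chunks). A self-contained sketch of the same pattern, requiring Python 3.11+ for asyncio.TaskGroup:

import asyncio

MAX_CONCURRENT = 3
CHUNK_SIZE = 10


async def process_all(items: list[str]) -> list[tuple[str, bool]]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)
    results: list[tuple[str, bool]] = []

    async def process_one(item: str) -> tuple[str, bool]:
        async with semaphore:
            await asyncio.sleep(0.01)  # stand-in for real per-item work
            return item, True

    for start in range(0, len(items), CHUNK_SIZE):
        chunk = items[start : start + CHUNK_SIZE]
        async with asyncio.TaskGroup() as tg:
            tasks = [tg.create_task(process_one(i)) for i in chunk]
        results.extend(t.result() for t in tasks)
        # ...persist progress here, as the mixin does after every chunk...
    return results


# asyncio.run(process_all([f"file-{i}" for i in range(25)]))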
    def _update_batch_status(self, batch_info: dict[str, Any]) -> None:
        """Update final batch status based on file processing results."""
        if batch_info["file_counts"]["failed"] == 0:
            batch_info["status"] = "completed"
        elif batch_info["file_counts"]["completed"] == 0:
            batch_info["status"] = "failed"
        else:
            batch_info["status"] = "completed"  # Partial success counts as completed

    async def _process_file_batch_async(
        self,
        batch_id: str,
        batch_info: dict[str, Any],
        override_file_ids: list[str] | None = None,
    ) -> None:
        """Process files in a batch asynchronously in the background."""
        file_ids = override_file_ids if override_file_ids is not None else batch_info["file_ids"]
        attributes = batch_info["attributes"]
        chunking_strategy = batch_info["chunking_strategy"]
        vector_store_id = batch_info["vector_store_id"]
        chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(VectorStoreChunkingStrategy)
        chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)

        try:
            # Process all files with controlled concurrency
            await self._process_files_with_concurrency(
                file_ids=file_ids,
                vector_store_id=vector_store_id,
                attributes=attributes,
                chunking_strategy_obj=chunking_strategy_obj,
                batch_id=batch_id,
                batch_info=batch_info,
            )

            # Update final batch status
            self._update_batch_status(batch_info)
            await self._save_openai_vector_store_file_batch(batch_id, batch_info)

            logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")

        except asyncio.CancelledError:
            logger.info(f"File batch {batch_id} processing was cancelled")
            # Clean up task reference if it still exists
            self._file_batch_tasks.pop(batch_id, None)
            raise  # Re-raise to ensure proper cancellation propagation
        finally:
            # Always clean up task reference when processing ends
            self._file_batch_tasks.pop(batch_id, None)

    def _get_and_validate_batch(self, batch_id: str, vector_store_id: str) -> dict[str, Any]:
        """Get and validate batch exists and belongs to vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise VectorStoreNotFoundError(vector_store_id)

        if batch_id not in self.openai_file_batches:
            raise ValueError(f"File batch {batch_id} not found")

        batch_info = self.openai_file_batches[batch_id]

        # Check if batch has expired (read-only check)
        expires_at = batch_info.get("expires_at")
        if expires_at:
            current_time = int(time.time())
            if current_time > expires_at:
                raise ValueError(f"File batch {batch_id} has expired after 7 days from creation")

        if batch_info["vector_store_id"] != vector_store_id:
            raise ValueError(f"File batch {batch_id} does not belong to vector store {vector_store_id}")

        return batch_info

    def _paginate_objects(
        self,
        objects: list[Any],
        limit: int | None = 20,
        after: str | None = None,
        before: str | None = None,
    ) -> tuple[list[Any], bool, str | None, str | None]:
        """Apply pagination to a list of objects with id fields."""
        limit = min(limit or 20, 100)  # Cap at 100 as per OpenAI

        # Find start index
        start_idx = 0
        if after:
            for i, obj in enumerate(objects):
                if obj.id == after:
                    start_idx = i + 1
                    break

        # Find end index
        end_idx = start_idx + limit
        if before:
            for i, obj in enumerate(objects[start_idx:], start_idx):
                if obj.id == before:
                    end_idx = i
                    break

        # Apply pagination
        paginated_objects = objects[start_idx:end_idx]

        # Determine pagination info
        has_more = end_idx < len(objects)
        first_id = paginated_objects[0].id if paginated_objects else None
        last_id = paginated_objects[-1].id if paginated_objects else None

        return paginated_objects, has_more, first_id, last_id

    async def openai_retrieve_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ) -> VectorStoreFileBatchObject:
        """Retrieve a vector store file batch."""
        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
        return VectorStoreFileBatchObject(**batch_info)

    async def openai_list_files_in_vector_store_file_batch(
        self,
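A short worked example of the cursor pagination above, with hypothetical items: given limit=2 and after="file-1", the window starts just past "file-1".

from dataclasses import dataclass


@dataclass
class Item:
    id: str
    created_at: int


items = [Item(id=f"file-{i}", created_at=i) for i in range(5)]
# _paginate_objects(items, limit=2, after="file-1") would return
#   page     -> [items[2], items[3]]  (ids "file-2" and "file-3")
#   has_more -> True
#   first_id -> "file-2", last_id -> "file-3"
# and the caller passes last_id back as `after` to fetch the next page.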
@@ -841,15 +1204,39 @@ class OpenAIVectorStoreMixin(ABC):
        order: str | None = "desc",
    ) -> VectorStoreFilesListInBatchResponse:
        """Returns a list of vector store files in a batch."""
        raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet")
        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
        batch_file_ids = batch_info["file_ids"]

    async def openai_retrieve_vector_store_file_batch(
        self,
        batch_id: str,
        vector_store_id: str,
    ) -> VectorStoreFileBatchObject:
        """Retrieve a vector store file batch."""
        raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet")
        # Load file objects for files in this batch
        batch_file_objects = []

        for file_id in batch_file_ids:
            try:
                file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
                file_object = VectorStoreFileObject(**file_info)

                # Apply status filter if provided
                if filter and file_object.status != filter:
                    continue

                batch_file_objects.append(file_object)
            except Exception as e:
                logger.warning(f"Could not load file {file_id} from batch {batch_id}: {e}")
                continue

        # Sort by created_at
        reverse_order = order == "desc"
        batch_file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order)

        # Apply pagination using helper
        paginated_files, has_more, first_id, last_id = self._paginate_objects(batch_file_objects, limit, after, before)

        return VectorStoreFilesListInBatchResponse(
            data=paginated_files,
            first_id=first_id,
            last_id=last_id,
            has_more=has_more,
        )

    async def openai_cancel_vector_store_file_batch(
        self,
@@ -857,4 +1244,24 @@ class OpenAIVectorStoreMixin(ABC):
        vector_store_id: str,
    ) -> VectorStoreFileBatchObject:
        """Cancel a vector store file batch."""
        raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet")
        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)

        if batch_info["status"] not in ["in_progress"]:
            raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_info['status']}")

        # Cancel the actual processing task if it exists
        if batch_id in self._file_batch_tasks:
            task = self._file_batch_tasks[batch_id]
            if not task.done():
                task.cancel()
                logger.info(f"Cancelled processing task for file batch: {batch_id}")
            # Remove from task tracking
            del self._file_batch_tasks[batch_id]

        batch_info["status"] = "cancelled"

        await self._save_openai_vector_store_file_batch(batch_id, batch_info)

        updated_batch = VectorStoreFileBatchObject(**batch_info)

        return updated_batch
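Cancellation above relies on standard asyncio semantics: task.cancel() raises CancelledError inside the task, which the batch worker re-raises after cleanup. A minimal self-contained illustration:

import asyncio


async def worker() -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for long-running batch processing
    except asyncio.CancelledError:
        # clean up, then re-raise so cancellation propagates (as the mixin does)
        raise


async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)  # let the worker start
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("worker cancelled")


asyncio.run(main())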