Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: File search tool for Responses API (#2426)
# What does this PR do?

This is an initial working prototype of wiring up the `file_search` builtin tool for the Responses API to our existing rag knowledge search tool. This is me seeing what I could pull together on top of the bits we already have merged.

This may not be the ideal way to implement this, and things like how I shuffle the vector store ids from the original Responses API tool request to the actual tool execution feel a bit hacky (grep for `tool_kwargs["vector_db_ids"]` in `_execute_tool_call` to see what I mean).

## Test Plan

I stubbed in some new tests to exercise this using text and pdf documents. Note that this currently lives under tests/verifications only because it sometimes flakes with tool calling of the small Llama-3.2-3B model we run in CI (and that I use as an example below). We'd want to make the test a bit more robust in some way if we moved this over to tests/integration and ran it in CI.

### OpenAI SaaS (to verify test correctness)

```
pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=https://api.openai.com/v1 \
  --model=gpt-4o
```

### Fireworks with faiss vector store

```
llama stack run llama_stack/templates/fireworks/run.yaml

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.3-70B-Instruct
```

### Ollama with faiss vector store

This sometimes flakes on Ollama because the quantized small model doesn't always choose to call the tool to answer the user's question, but it often works.

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
  llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=meta-llama/Llama-3.2-3B-Instruct
```

### OpenAI provider with sqlite-vec vector store

```
llama stack run ./llama_stack/templates/starter/run.yaml --image-type venv

pytest -sv tests/verifications/openai_api/test_responses.py \
  -k 'file_search' \
  --base-url=http://localhost:8321/v1/openai/v1 \
  --model=openai/gpt-4o-mini
```

### Ensure existing vector store integration tests still pass

```
ollama run llama3.2:3b

INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
  llama stack run ./llama_stack/templates/ollama/run.yaml \
  --image-type venv \
  --env OLLAMA_URL="http://0.0.0.0:11434"

LLAMA_STACK_CONFIG=http://localhost:8321 \
  pytest -sv tests/integration/vector_io \
  --text-model "meta-llama/Llama-3.2-3B-Instruct" \
  --embedding-model=all-MiniLM-L6-v2
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
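To complement the pytest commands above, here is a minimal client-side sketch of the flow this PR enables. It is not part of the PR itself; it assumes a stack already running at `http://localhost:8321`, an `openai` Python client new enough to expose `client.responses.create`, and an existing, populated vector store (the `vs_...` id below is a placeholder).

```python
from openai import OpenAI

# Llama Stack exposes an OpenAI-compatible surface under /v1/openai/v1.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What does the attached document say about Llama Stack vector stores?",
    tools=[
        {
            "type": "file_search",
            # Placeholder id: substitute a vector store you created and attached files to.
            "vector_store_ids": ["vs_0000000000000000"],
        }
    ],
)

# Expect a file_search tool call item in the output, followed by the assistant message.
for item in response.output:
    print(item.type)
print(response.output_text)
```

The verification tests in `tests/verifications/openai_api/test_responses.py` exercise essentially this flow against both OpenAI and Llama Stack endpoints.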
This commit is contained in:
Parent: 554ada57b0
Commit: 941f505eb0

28 changed files with 1105 additions and 24 deletions
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
@@ -34,6 +35,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
@@ -62,7 +64,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
@@ -198,7 +200,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
 class ChatCompletionContext(BaseModel):
     model: str
     messages: list[OpenAIMessageParam]
-    tools: list[ChatCompletionToolParam] | None = None
+    response_tools: list[OpenAIResponseInputTool] | None = None
+    chat_tools: list[ChatCompletionToolParam] | None = None
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     temperature: float | None
     response_format: OpenAIResponseFormatParam
@@ -388,7 +391,8 @@ class OpenAIResponsesImpl:
         ctx = ChatCompletionContext(
             model=model,
             messages=messages,
-            tools=chat_tools,
+            response_tools=tools,
+            chat_tools=chat_tools,
             mcp_tool_to_server=mcp_tool_to_server,
             temperature=temperature,
             response_format=response_format,
@@ -417,7 +421,7 @@ class OpenAIResponsesImpl:
         completion_result = await self.inference_api.openai_chat_completion(
             model=ctx.model,
             messages=messages,
-            tools=ctx.tools,
+            tools=ctx.chat_tools,
             stream=True,
             temperature=ctx.temperature,
             response_format=ctx.response_format,
@@ -606,6 +610,12 @@ class OpenAIResponsesImpl:
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")
                 chat_tools.append(make_openai_tool(tool_name, tool))
+            elif input_tool.type == "file_search":
+                tool_name = "knowledge_search"
+                tool = await self.tool_groups_api.get_tool(tool_name)
+                if not tool:
+                    raise ValueError(f"Tool {tool_name} not found")
+                chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "mcp":
                 always_allowed = None
                 never_allowed = None
@@ -667,6 +677,7 @@ class OpenAIResponsesImpl:

         tool_call_id = tool_call.id
         function = tool_call.function
+        tool_kwargs = json.loads(function.arguments) if function.arguments else {}

         if not function or not tool_call_id or not function.name:
             return None, None
@@ -680,12 +691,26 @@ class OpenAIResponsesImpl:
                     endpoint=mcp_tool.server_url,
                     headers=mcp_tool.headers or {},
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
             else:
+                if function.name == "knowledge_search":
+                    response_file_search_tool = next(
+                        t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    )
+                    if response_file_search_tool:
+                        if response_file_search_tool.filters:
+                            logger.warning("Filters are not yet supported for file_search tool")
+                        if response_file_search_tool.ranking_options:
+                            logger.warning("Ranking options are not yet supported for file_search tool")
+                        tool_kwargs["vector_db_ids"] = response_file_search_tool.vector_store_ids
+                        tool_kwargs["query_config"] = RAGQueryConfig(
+                            mode="vector",
+                            max_chunks=response_file_search_tool.max_num_results,
+                        )
                 result = await self.tool_runtime_api.invoke_tool(
                     tool_name=function.name,
-                    kwargs=json.loads(function.arguments) if function.arguments else {},
+                    kwargs=tool_kwargs,
                 )
         except Exception as e:
             error_exc = e
@@ -713,6 +738,27 @@ class OpenAIResponsesImpl:
             )
             if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
                 message.status = "failed"
+        elif function.name == "knowledge_search":
+            message = OpenAIResponseOutputMessageFileSearchToolCall(
+                id=tool_call_id,
+                queries=[tool_kwargs.get("query", "")],
+                status="completed",
+            )
+            if "document_ids" in result.metadata:
+                message.results = []
+                for i, doc_id in enumerate(result.metadata["document_ids"]):
+                    text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
+                    score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    message.results.append(
+                        {
+                            "file_id": doc_id,
+                            "filename": doc_id,
+                            "text": text,
+                            "score": score,
+                        }
+                    )
+            if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+                message.status = "failed"
         else:
             raise ValueError(f"Unknown tool {function.name} called")

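To make the `knowledge_search` branch above easier to review, here is a rough, hand-written sketch of the output item it assembles. The field names mirror the diff, but the id, query, and result values are invented for illustration.

```python
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageFileSearchToolCall

# Hypothetical values; in the real flow the id and query come from the model's
# tool call and the results come from the rag tool's metadata
# (document_ids / chunks / scores added elsewhere in this PR).
message = OpenAIResponseOutputMessageFileSearchToolCall(
    id="call_example123",
    queries=["llama stack vector stores"],
    status="completed",
)
message.results = [
    {
        "file_id": "file-abc",   # doc_id reported by the rag tool
        "filename": "file-abc",  # currently the same value; no original filename is tracked yet
        "text": "Chunk of the matched document...",
        "score": 0.87,
    }
]
```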
@@ -170,6 +170,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+                "chunks": [c.content for c in chunks[: len(picked)]],
+                "scores": scores[: len(picked)],
             },
         )
@@ -16,6 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOAdapter(config, deps[Api.inference])
+    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
     await impl.initialize()
     return impl
@@ -15,6 +15,7 @@ import faiss
 import numpy as np
 from numpy.typing import NDArray

+from llama_stack.apis.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
@@ -132,9 +133,10 @@ class FaissIndex(EmbeddingIndex):


 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
         self.config = config
         self.inference_api = inference_api
+        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.kvstore: KVStore | None = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
@@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
     await impl.initialize()
     return impl
@@ -17,6 +17,7 @@ import numpy as np
 import sqlite_vec
 from numpy.typing import NDArray

+from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
@@ -301,9 +302,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Inference) -> None:
+    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
         self.config = config
         self.inference_api = inference_api
+        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}

@@ -24,6 +24,7 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
@@ -32,6 +33,7 @@ def available_providers() -> list[ProviderSpec]:
             module="llama_stack.providers.inline.vector_io.faiss",
             config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         # NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
         # source distribution and the wheels are not available for all platforms.
@@ -42,6 +44,7 @@ def available_providers() -> list[ProviderSpec]:
             module="llama_stack.providers.inline.vector_io.sqlite_vec",
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         InlineProviderSpec(
             api=Api.vector_io,
@@ -51,6 +54,7 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
             deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
             api_dependencies=[Api.inference],
+            optional_api_dependencies=[Api.files],
         ),
         remote_provider_spec(
             Api.vector_io,
@@ -23,6 +23,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.memory.vector_store import (
@@ -241,3 +242,12 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
+
+    async def openai_attach_file_to_vector_store(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+    ) -> VectorStoreFileObject:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
@@ -25,6 +25,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
 from llama_stack.providers.utils.memory.vector_store import (
@@ -240,6 +241,15 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

+    async def openai_attach_file_to_vector_store(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+    ) -> VectorStoreFileObject:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
+

 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""
@@ -23,6 +23,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreObject,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.memory.vector_store import (
@@ -241,3 +242,12 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         rewrite_query: bool | None = False,
     ) -> VectorStoreSearchResponsePage:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
+
+    async def openai_attach_file_to_vector_store(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+    ) -> VectorStoreFileObject:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
@@ -5,11 +5,13 @@
 # the root directory of this source tree.

 import logging
+import mimetypes
 import time
 import uuid
 from abc import ABC, abstractmethod
 from typing import Any

+from llama_stack.apis.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     QueryChunksResponse,
@@ -20,6 +22,15 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.apis.vector_io.vector_io import (
+    Chunk,
+    VectorStoreChunkingStrategy,
+    VectorStoreChunkingStrategyAuto,
+    VectorStoreChunkingStrategyStatic,
+    VectorStoreFileLastError,
+    VectorStoreFileObject,
+)
+from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

 logger = logging.getLogger(__name__)

@@ -36,6 +47,7 @@ class OpenAIVectorStoreMixin(ABC):

     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
+    files_api: Files | None

     @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
@@ -67,6 +79,16 @@
         """Unregister a vector database (provider-specific implementation)."""
         pass

+    @abstractmethod
+    async def insert_chunks(
+        self,
+        vector_db_id: str,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
+    ) -> None:
+        """Insert chunks into a vector database (provider-specific implementation)."""
+        pass
+
     @abstractmethod
     async def query_chunks(
         self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
@@ -383,3 +405,78 @@
             if metadata[key] != value:
                 return False
         return True
+
+    async def openai_attach_file_to_vector_store(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+    ) -> VectorStoreFileObject:
+        attributes = attributes or {}
+        chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
+
+        vector_store_file_object = VectorStoreFileObject(
+            id=file_id,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+            created_at=int(time.time()),
+            status="in_progress",
+            vector_store_id=vector_store_id,
+        )
+
+        if not hasattr(self, "files_api") or not self.files_api:
+            vector_store_file_object.status = "failed"
+            vector_store_file_object.last_error = VectorStoreFileLastError(
+                code="server_error",
+                message="Files API is not available",
+            )
+            return vector_store_file_object
+
+        if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
+            max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
+            chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
+        else:
+            # Default values from OpenAI API spec
+            max_chunk_size_tokens = 800
+            chunk_overlap_tokens = 400
+
+        try:
+            file_response = await self.files_api.openai_retrieve_file(file_id)
+            mime_type, _ = mimetypes.guess_type(file_response.filename)
+            content_response = await self.files_api.openai_retrieve_file_content(file_id)
+
+            content = content_from_data_and_mime_type(content_response.body, mime_type)
+
+            chunks = make_overlapped_chunks(
+                file_id,
+                content,
+                max_chunk_size_tokens,
+                chunk_overlap_tokens,
+                attributes,
+            )
+
+            if not chunks:
+                vector_store_file_object.status = "failed"
+                vector_store_file_object.last_error = VectorStoreFileLastError(
+                    code="server_error",
+                    message="No chunks were generated from the file",
+                )
+                return vector_store_file_object
+
+            await self.insert_chunks(
+                vector_db_id=vector_store_id,
+                chunks=chunks,
+            )
+        except Exception as e:
+            logger.error(f"Error attaching file to vector store: {e}")
+            vector_store_file_object.status = "failed"
+            vector_store_file_object.last_error = VectorStoreFileLastError(
+                code="server_error",
+                message=str(e),
+            )
+            return vector_store_file_object

+        vector_store_file_object.status = "completed"

+        return vector_store_file_object
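The new `openai_attach_file_to_vector_store` path above backs the OpenAI-compatible "attach file to vector store" route. As a hedged sketch of how a client might drive it end to end (not taken from the PR; depending on your `openai` client version the vector store methods may live under `client.beta.vector_stores` instead of `client.vector_stores`):

```python
from io import BytesIO

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload a document through the Files API; the mixin later retrieves it,
# guesses the mime type from the filename, chunks it, and inserts the chunks.
uploaded = client.files.create(
    file=("notes.txt", BytesIO(b"Llama Stack wires file_search to the rag knowledge_search tool.")),
    purpose="assistants",
)

vector_store = client.vector_stores.create(name="demo-docs")
attached = client.vector_stores.files.create(
    vector_store_id=vector_store.id,
    file_id=uploaded.id,
)
print(attached.status)  # "completed" on success, "failed" with last_error otherwise
```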
@@ -72,16 +72,18 @@ def content_from_data(data_url: str) -> str:
         data = unquote(data)
         encoding = parts["encoding"] or "utf-8"
         data = data.encode(encoding)
+    return content_from_data_and_mime_type(data, parts["mimetype"], parts.get("encoding", None))

-    encoding = parts["encoding"]
-    if not encoding:
-        import chardet

-        detected = chardet.detect(data)
-        encoding = detected["encoding"]
+def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, encoding: str | None = None) -> str:
+    if isinstance(data, bytes):
+        if not encoding:
+            import chardet

-    mime_type = parts["mimetype"]
-    mime_category = mime_type.split("/")[0]
+            detected = chardet.detect(data)
+            encoding = detected["encoding"]
+
+    mime_category = mime_type.split("/")[0] if mime_type else None
     if mime_category == "text":
         # For text-based files (including CSV, MD)
         return data.decode(encoding)
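A small illustrative call of the refactored helper, assuming `chardet` is installed (it is only imported when bytes arrive without an explicit encoding):

```python
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type

# Bytes with a text/* mime type are decoded to str; the encoding is
# auto-detected with chardet when the caller does not pass one.
text = content_from_data_and_mime_type(b"hello from a plain-text file", "text/plain")
print(text)
```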