Merge branch 'main' into agents-openai-migration

This commit is contained in:
Matthew Farrellee 2025-09-30 17:51:45 -04:00
commit 724322eeb2
673 changed files with 164269 additions and 14378 deletions

View file

@ -830,6 +830,8 @@ class ChatAgent(ShieldRunnerMixin):
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
@ -873,6 +875,8 @@ class ChatAgent(ShieldRunnerMixin):
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
@ -952,7 +956,7 @@ async def get_raw_document_text(document: Document) -> str:
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):

View file

@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
)
# Create orchestrator and delegate streaming logic

View file

@ -10,10 +10,12 @@ from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@ -50,6 +52,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal
logger = get_logger(name=__name__, category="agents::meta_reference")
def convert_tooldef_to_chat_tool(tool_def):
"""Convert a ToolDef to OpenAI ChatCompletionToolParam format.
Args:
tool_def: ToolDef from the tools API
Returns:
ChatCompletionToolParam suitable for OpenAI chat completion
"""
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
internal_tool_def = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
items=param.items,
)
for param in tool_def.parameters
},
)
return convert_tooldef_to_openai_tool(internal_tool_def)
class StreamingResponseOrchestrator:
def __init__(
self,
@ -117,10 +149,17 @@ class StreamingResponseOrchestrator:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
@ -164,10 +203,11 @@ class StreamingResponseOrchestrator:
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
approvals = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
@ -178,9 +218,23 @@ class StreamingResponseOrchestrator:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
if self._approval_required(tool_call.function.name):
approval_response = self.ctx.approval_response(
tool_call.function.name, tool_call.function.arguments
)
if approval_response:
if approval_response.approve:
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
@ -556,23 +610,7 @@ class StreamingResponseOrchestrator:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
@ -632,3 +670,46 @@ class StreamingResponseOrchestrator:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
mcp_server = self.mcp_tool_to_server[tool_name]
if mcp_server.require_approval == "always":
return True
if mcp_server.require_approval == "never":
return False
if isinstance(mcp_server, ApprovalFilter):
if tool_name in mcp_server.always:
return True
if tool_name in mcp_server.never:
return False
return True
async def _add_mcp_approval_request(
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
mcp_server = self.mcp_tool_to_server[tool_name]
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
arguments=arguments,
id=f"approval_{uuid.uuid4()}",
name=tool_name,
server_label=mcp_server.server_label,
)
output_messages.append(mcp_approval_request)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)

View file

@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None

View file

@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def unregister_benchmark(self, benchmark_id: str) -> None:
if benchmark_id in self.benchmarks:
del self.benchmarks[benchmark_id]
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
await self.kvstore.delete(key)
async def run_eval(
self,
benchmark_id: str,
@ -152,31 +159,40 @@ class MetaReferenceEvalImpl(
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
if candidate.sampling_params.stop:
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
response = await self.inference_api.openai_completion(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
prompt=input_content,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
input_messages = [
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
response = await self.inference_api.openai_chat_completion(
model=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -9,11 +9,12 @@ import uuid
from pathlib import Path
from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
@ -22,6 +23,7 @@ from llama_stack.apis.files import (
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -44,7 +46,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -74,7 +76,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -86,14 +88,13 @@ class LocalfsFilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
if expires_after_anchor is not None or expires_after_seconds is not None:
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
file_id = self._generate_file_id()
@ -150,7 +151,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_chat_completion([request])
return results[0]
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:

View file

@ -290,13 +290,13 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
content = response.completion_message.content
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)

View file

@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
async def register_scoring_function(self, function_def: ScoringFn) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
async def score_batch(
self,
dataset_id: str,

View file

@ -224,10 +224,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Always log to console if console sink is enabled (debug)
if TelemetrySink.CONSOLE in self.config.sinks:
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
# Add metric as an event to the current span
try:
with self._lock:

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
content = template.render({"messages": messages})
rendered_content: str = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)
response = await inference_api.chat_completion(
model_id=model,
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
model=model,
messages=[message],
stream=False,
)
query = response.completion_message.content
query = response.choices[0].message.content
return query

View file

@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
parse_data_url,
)
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return
for doc in documents:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
mime_type = parts["mimetype"]
else:
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
)
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
return ToolInvocationResult(
content=result.content,
content=result.content or [],
metadata=result.metadata,
)