Merge remote-tracking branch 'origin/main' into TamiTakamiya/tool-param-definition-update

This commit is contained in:
Ashwin Bharambe 2025-09-27 10:47:08 -07:00
commit c1818350c8
479 changed files with 74743 additions and 8997 deletions

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_chat_completion([request])
return results[0]
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]: