llama-stack/llama_stack/providers/utils/inference/stream_utils.py
ehhuang 549812f51e
feat: implement get chat completions APIs (#2200)
# What does this PR do?
* Provide sqlite implementation of the APIs introduced in
https://github.com/meta-llama/llama-stack/pull/2145.
* Introduced a SqlStore API: llama_stack/providers/utils/sqlstore/api.py
and the first Sqlite implementation
* Pagination support will be added in a future PR.

## Test Plan
Unit test on sql store:
<img width="1005" alt="image"
src="https://github.com/user-attachments/assets/9b8b7ec8-632b-4667-8127-5583426b2e29"
/>


Integration test:
```
INFERENCE_MODEL="llama3.2:3b-instruct-fp16" llama stack build --template ollama --image-type conda --run
```
```
LLAMA_STACK_CONFIG=http://localhost:5001 INFERENCE_MODEL="llama3.2:3b-instruct-fp16" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-fp16" -k 'inference_store and openai'
```
2025-05-21 22:21:52 -07:00

129 lines
6.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import Any
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAIMessageParam,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
def _new_choice_state() -> dict[str, Any]:
    """Return an empty accumulator for a single streamed choice."""
    return {
        "content_parts": [],
        "tool_calls_builder": {},
        "finish_reason": None,
        "logprobs_content_parts": [],
    }


def _accumulate_tool_call(builders: dict[Any, dict[str, Any]], tool_call_delta: Any) -> None:
    """Merge one streamed tool-call fragment into the builder keyed by its index."""
    tc_idx = tool_call_delta.index
    if tc_idx not in builders:
        builders[tc_idx] = {
            "id": None,
            "type": "function",
            "function_name_parts": [],
            "function_arguments_parts": [],
        }
    builder = builders[tc_idx]

    # id/type are overwritten by the latest truthy value; name and arguments
    # are collected as fragments and joined once the stream is finished.
    if tool_call_delta.id:
        builder["id"] = tool_call_delta.id
    if tool_call_delta.type:
        builder["type"] = tool_call_delta.type
    function = tool_call_delta.function
    if function:
        if function.name:
            builder["function_name_parts"].append(function.name)
        if function.arguments:
            builder["function_arguments_parts"].append(function.arguments)


def _accumulate_chunk(choices_data: dict[int, dict[str, Any]], chunk: OpenAIChatCompletionChunk) -> None:
    """Fold every choice delta carried by *chunk* into *choices_data*."""
    for choice_delta in chunk.choices or ():
        idx = choice_delta.index
        if idx not in choices_data:
            choices_data[idx] = _new_choice_state()
        state = choices_data[idx]

        delta = choice_delta.delta
        if delta:
            if delta.content:
                state["content_parts"].append(delta.content)
            for tool_call_delta in delta.tool_calls or ():
                _accumulate_tool_call(state["tool_calls_builder"], tool_call_delta)

        if choice_delta.finish_reason:
            state["finish_reason"] = choice_delta.finish_reason

        logprobs = choice_delta.logprobs
        if logprobs and logprobs.content:
            state["logprobs_content_parts"].extend(logprobs.content)


def _assemble_choices(choices_data: dict[int, dict[str, Any]]) -> list[OpenAIChoice]:
    """Build the final choices, ordered by choice index, from the accumulated deltas."""
    assembled: list[OpenAIChoice] = []
    # Sort by index so the stored completion lists choices in index order even
    # when the provider interleaved them differently while streaming.
    for choice_idx in sorted(choices_data):
        state = choices_data[choice_idx]
        content = "".join(state["content_parts"])

        # Tool calls keep their arrival order. Builders that never received an
        # id are skipped. "type" needs no fallback here: it is initialised to
        # "function" when the builder is created.
        tool_calls = [
            OpenAIChatCompletionToolCall(
                id=builder["id"],
                type=builder["type"],
                function=OpenAIChatCompletionToolCallFunction(
                    name="".join(builder["function_name_parts"]),
                    arguments="".join(builder["function_arguments_parts"]),
                ),
            )
            for builder in state["tool_calls_builder"].values()
            if builder["id"]
        ]

        logprobs_content = state["logprobs_content_parts"]
        assembled.append(
            OpenAIChoice(
                finish_reason=state["finish_reason"],
                index=choice_idx,
                message=OpenAIAssistantMessageParam(
                    role="assistant",
                    # Empty content / tool-call lists are stored as None.
                    content=content or None,
                    tool_calls=tool_calls or None,
                ),
                logprobs=OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None,
            )
        )
    return assembled


async def stream_and_store_openai_completion(
    provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
    model: str,
    store: InferenceStore,
    input_messages: list[OpenAIMessageParam],
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """
    Wraps a provider's stream, yields chunks, and stores the full completion at the end.

    Every chunk is yielded unchanged after its deltas (content, tool calls,
    finish reason, logprobs) have been accumulated per choice index.

    Args:
        provider_stream: Async iterator of chunks produced by the provider.
        model: Model name recorded on the stored completion.
        store: Store whose ``store_chat_completion`` receives the assembled completion.
        input_messages: Request messages stored alongside the completion.

    Yields:
        Each chunk from ``provider_stream``, in order.

    Notes:
        Storage happens in a ``finally`` block, so whatever was accumulated is
        still stored if the stream ends with an error or the consumer stops
        early. Nothing is stored when no chunk carried a truthy ``id``.
        NOTE(review): an exception raised while assembling or storing will
        replace any exception that was propagating from the stream.
    """
    completion_id = None
    created = None
    choices_data: dict[int, dict[str, Any]] = {}

    try:
        async for chunk in provider_stream:
            # The first truthy id / created timestamp seen wins.
            if completion_id is None and chunk.id:
                completion_id = chunk.id
            if created is None and chunk.created:
                created = chunk.created
            _accumulate_chunk(choices_data, chunk)
            yield chunk
    finally:
        if completion_id:
            final_response = OpenAIChatCompletion(
                id=completion_id,
                choices=_assemble_choices(choices_data),
                # Fall back to the current UTC time if no chunk carried one.
                created=created or int(datetime.now(timezone.utc).timestamp()),
                model=model,
                object="chat.completion",
            )
            await store.store_chat_completion(final_response, input_messages)