Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
fix: address review comments
This commit addresses review comments on the OpenAI chat completion implementation in the meta_reference provider.
Tool Augmentation
- Add `augment_raw_messages_for_tools()` to properly inject tool definitions into prompts
- Support model-family-specific tool formats:
* Llama 3.1/3.2 multimodal: JsonCustomToolGenerator with JSON format
* Llama 3.2/3.3/4: PythonListCustomToolGenerator with Python list format
- Handle tool_choice hints (auto/required/specific tool)
- Preserve existing system messages while adding tool context (a request-level usage sketch follows this list)
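
For illustration, a minimal sketch of a request that exercises this augmentation path. The model id, tool schema, and the `impl` handle are hypothetical placeholders (not part of this change), and the dict-style tool literal assumes pydantic coercion into the OpenAI tool param type:

from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)

# Inside an async context; `impl` stands in for an initialized meta_reference
# inference implementation (hypothetical wiring).
params = OpenAIChatCompletionRequestWithExtraBody(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[OpenAIUserMessageParam(role="user", content="What's the weather in Paris?")],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",  # "required" or a specific tool name injects a MUST-use hint into the system prompt
)
response = await impl.openai_chat_completion(params=params)
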
Streaming & Tool Call Detection
- Implement streaming support via `params.stream` with `_stream_chat_completion()`
- Add tool call detection by decoding assistant messages after generation
- Set proper `finish_reason` based on content ("stop" vs "tool_calls")
- Convert internal ToolCall format to OpenAI-compatible types
- Stream chunks incrementally with proper delta formatting (a consumer sketch follows this list)
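
A hedged sketch of how a caller might consume the resulting stream; `impl` and `params` are the hypothetical objects from the sketch above, with `stream=True` set on the request, and the accumulation logic is illustrative rather than prescribed by this change:

# Inside an async context; params is as above but constructed with stream=True.
content_parts: list[str] = []
tool_calls = []
finish_reason = None

stream = await impl.openai_chat_completion(params=params)
async for chunk in stream:
    choice = chunk.choices[0]
    if choice.delta.content:
        content_parts.append(choice.delta.content)  # incremental text deltas
    if choice.delta.tool_calls:
        tool_calls.extend(choice.delta.tool_calls)  # emitted once the full message is decoded
    if choice.finish_reason:
        finish_reason = choice.finish_reason  # "tool_calls" or "stop" on the final chunk

print("".join(content_parts), finish_reason, [tc.function.name for tc in tool_calls])
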
Type Corrections
- Fix response_format handling in generators.py to properly extract the schema from the
  OpenAIJSONSchema TypedDict and use the correct ResponseFormatType enum (a schema-extraction sketch follows this list)
- Use correct OpenAI types: OpenAIChatCompletionToolCall, OpenAIChunkChoice,
OpenAIChoiceDelta, OpenAIChatCompletionToolCallFunction
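
For context, a minimal sketch of the schema extraction described above; the response_format literal is a hypothetical example, and constructing it from a plain dict assumes pydantic coercion of the OpenAIJSONSchema TypedDict:

from llama_stack.apis.inference import (
    JsonSchemaResponseFormat,
    OpenAIResponseFormatJSONSchema,
    ResponseFormatType,
)

# Hypothetical client-supplied response_format (OpenAI "json_schema" style).
response_format = OpenAIResponseFormatJSONSchema(
    json_schema={
        "name": "weather_report",
        "schema": {"type": "object", "properties": {"temperature_c": {"type": "number"}}},
    }
)

# The json_schema field is an OpenAIJSONSchema TypedDict; the generator only needs
# its nested "schema" payload for grammar-constrained decoding.
schema_dict = response_format.json_schema.get("schema") or {}
json_schema_format = JsonSchemaResponseFormat(
    type=ResponseFormatType.json_schema,
    json_schema=schema_dict,
)
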
Signed-off-by: Charlie Doern <cdoern@redhat.com>
Parent: dac1ff1f57
Commit: 1b77826aba
2 changed files with 345 additions and 9 deletions
generators.py

@@ -14,7 +14,9 @@ from llama_stack.apis.inference import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIResponseFormatJSONSchema,
     ResponseFormat,
+    ResponseFormatType,
     SamplingParams,
     TopPSamplingStrategy,
 )
@@ -163,7 +165,8 @@ class LlamaGenerator:
         sampling_params = SamplingParams()
         if request.temperature is not None or request.top_p is not None:
             sampling_params.strategy = TopPSamplingStrategy(
-                temperature=request.temperature or 1.0, top_p=request.top_p or 1.0
+                temperature=request.temperature if request.temperature is not None else 1.0,
+                top_p=request.top_p if request.top_p is not None else 1.0,
             )
         if request.max_tokens:
             sampling_params.max_tokens = request.max_tokens
@@ -177,9 +180,12 @@ class LlamaGenerator:
         # Get logits processor for response format
         logits_processor = None
         if request.response_format:
-            if isinstance(request.response_format, dict) and request.response_format.get("type") == "json_schema":
+            if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
+                # Extract the actual schema from OpenAIJSONSchema TypedDict
+                schema_dict = request.response_format.json_schema.get("schema") or {}
                 json_schema_format = JsonSchemaResponseFormat(
-                    type="json_schema", json_schema=request.response_format.get("json_schema", {})
+                    type=ResponseFormatType.json_schema,
+                    json_schema=schema_dict,
                 )
                 logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
Second changed file (the meta_reference provider's OpenAI chat completion implementation)

@@ -16,6 +16,8 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionUsage,
     OpenAIChoice,
     OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+    ToolChoice,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
@@ -24,12 +26,20 @@ from llama_stack.apis.inference.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
+from llama_stack.models.llama.llama3.prompt_templates import (
+    JsonCustomToolGenerator,
+    SystemDefaultGenerator,
+)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -49,6 +59,170 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)


+def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
+    """Convert OpenAI tool format to ToolDefinition format."""
+    # OpenAI tools have function.name and function.parameters
+    return ToolDefinition(
+        tool_name=tool.function.name,
+        description=tool.function.description or "",
+        parameters=tool.function.parameters or {},
+    )
+
+
+def _get_tool_choice_prompt(tool_choice, tools) -> str:
+    """Generate prompt text for tool_choice behavior."""
+    if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
+        return ""
+    elif tool_choice == ToolChoice.required or tool_choice == "required":
+        return "You MUST use one of the provided functions/tools to answer the user query."
+    elif tool_choice == ToolChoice.none or tool_choice == "none":
+        return ""
+    else:
+        # Specific tool specified
+        return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def _raw_content_as_str(content) -> str:
+    """Convert RawContent to string for system messages."""
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, RawTextItem):
+        return content.text
+    elif isinstance(content, list):
+        return "\n".join(_raw_content_as_str(c) for c in content)
+    else:
+        return "<media>"
+
+
+def _augment_raw_messages_for_tools_llama_3_1(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 3.1 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions first (if present)
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # For OpenAI format, all tools are custom (have string names)
+        tool_gen = JsonCustomToolGenerator()
+        tool_template = tool_gen.gen(tool_definitions)
+        sys_content += tool_template.render()
+        sys_content += "\n"
+
+    # Add default system prompt
+    default_gen = SystemDefaultGenerator()
+    default_template = default_gen.gen()
+    sys_content += default_template.render()
+
+    # Add existing system message if present
+    if existing_system_message:
+        sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    # Create new system message
+    new_system_message = RawMessage(
+        role="system",
+        content=[RawTextItem(text=sys_content.strip())],
+    )
+
+    return [new_system_message] + messages
+
+
+def _augment_raw_messages_for_tools_llama_4(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions if present
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # Use python_list format for Llama 4
+        tool_gen = PythonListCustomToolGeneratorLlama4()
+        system_prompt = None
+        if existing_system_message:
+            system_prompt = _raw_content_as_str(existing_system_message.content)
+
+        tool_template = tool_gen.gen(tool_definitions, system_prompt)
+        sys_content = tool_template.render()
+    elif existing_system_message:
+        # No tools, just use existing system message
+        sys_content = _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    if sys_content:
+        new_system_message = RawMessage(
+            role="system",
+            content=[RawTextItem(text=sys_content.strip())],
+        )
+        return [new_system_message] + messages
+
+    return messages
+
+
+def augment_raw_messages_for_tools(
+    raw_messages: list[RawMessage],
+    params: OpenAIChatCompletionRequestWithExtraBody,
+    llama_model,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions based on model family."""
+    if not params.tools:
+        return raw_messages
+
+    # Determine augmentation strategy based on model family
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # Llama 3.1 and Llama 3.2 multimodal use JSON format
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    elif llama_model.model_family in (
+        ModelFamily.llama3_2,
+        ModelFamily.llama3_3,
+        ModelFamily.llama4,
+    ):
+        # Llama 3.2/3.3/4 use python_list format
+        return _augment_raw_messages_for_tools_llama_4(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    else:
+        # Default to Llama 3.1 style
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+
+
 def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
@@ -141,7 +315,6 @@ class MetaReferenceInferenceImpl(
         self.llama_model = llama_model

         log.info("Warming up...")
-        from llama_stack.apis.inference import OpenAIUserMessageParam

         await self.openai_chat_completion(
             params=OpenAIChatCompletionRequestWithExtraBody(
@@ -167,10 +340,17 @@ class MetaReferenceInferenceImpl(
         self.check_model(params)

         # Convert OpenAI messages to RawMessages
-        from llama_stack.providers.utils.inference.prompt_adapter import convert_openai_message_to_raw_message
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            convert_openai_message_to_raw_message,
+            decode_assistant_message,
+        )

         raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]

+        # Augment messages with tool definitions if tools are present
+        raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
+
         # Call generator's chat_completion method (works for both single-GPU and model-parallel)
         if isinstance(self.generator, LlamaGenerator):
             generator = self.generator.chat_completion(params, raw_messages)
@@ -178,14 +358,56 @@ class MetaReferenceInferenceImpl(
             # Model parallel: submit task to process group
             generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))

-        # Collect all generated text
+        # Check if streaming is requested
+        if params.stream:
+            return self._stream_chat_completion(generator, params)
+
+        # Non-streaming: collect all generated text
         generated_text = ""
         for result_batch in generator:
             for result in result_batch:
                 if not result.ignore_token and result.source == "output":
                     generated_text += result.text

+        # Decode assistant message to extract tool calls and determine stop_reason
+        # Default to end_of_turn if generation completed normally
+        decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
+
+        # Convert tool calls to OpenAI format
+        openai_tool_calls = None
+        if decoded_message.tool_calls:
+            from llama_stack.apis.inference import (
+                OpenAIChatCompletionToolCall,
+                OpenAIChatCompletionToolCallFunction,
+            )
+
+            openai_tool_calls = [
+                OpenAIChatCompletionToolCall(
+                    # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+                    id=f"call_{uuid.uuid4().hex[:24]}",
+                    type="function",
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=str(tc.tool_name),
+                        arguments=tc.arguments,
+                    ),
+                )
+                for tc in decoded_message.tool_calls
+            ]
+
+        # Determine finish_reason based on whether tool calls are present
+        finish_reason = "tool_calls" if openai_tool_calls else "stop"
+
+        # Extract content from decoded message
+        content = ""
+        if isinstance(decoded_message.content, str):
+            content = decoded_message.content
+        elif isinstance(decoded_message.content, list):
+            for item in decoded_message.content:
+                if isinstance(item, RawTextItem):
+                    content += item.text
+
         # Create OpenAI response
+        # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
         response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
         created = int(time.time())
@@ -199,9 +421,10 @@ class MetaReferenceInferenceImpl(
                     index=0,
                     message=OpenAIAssistantMessageParam(
                         role="assistant",
-                        content=generated_text,
+                        content=content,
+                        tool_calls=openai_tool_calls,
                     ),
-                    finish_reason="stop",
+                    finish_reason=finish_reason,
                     logprobs=None,
                 )
             ],
@@ -211,3 +434,110 @@ class MetaReferenceInferenceImpl(
                 total_tokens=0,  # TODO: calculate properly
             ),
         )
+
+    async def _stream_chat_completion(
+        self,
+        generator,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """Stream chat completion chunks as they're generated."""
+        from llama_stack.apis.inference import (
+            OpenAIChatCompletionChunk,
+            OpenAIChatCompletionToolCall,
+            OpenAIChatCompletionToolCallFunction,
+            OpenAIChoiceDelta,
+            OpenAIChunkChoice,
+        )
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
+
+        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+        created = int(time.time())
+        generated_text = ""
+
+        # Yield chunks as tokens are generated
+        for result_batch in generator:
+            for result in result_batch:
+                if result.ignore_token or result.source != "output":
+                    continue
+
+                generated_text += result.text
+
+                # Yield delta chunk with the new text
+                chunk = OpenAIChatCompletionChunk(
+                    id=response_id,
+                    object="chat.completion.chunk",
+                    created=created,
+                    model=params.model,
+                    choices=[
+                        OpenAIChunkChoice(
+                            index=0,
+                            delta=OpenAIChoiceDelta(
+                                role="assistant",
+                                content=result.text,
+                            ),
+                            finish_reason="",
+                            logprobs=None,
+                        )
+                    ],
+                )
+                yield chunk
+
+        # After generation completes, decode the full message to extract tool calls
+        decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
+
+        # If tool calls are present, yield a final chunk with tool_calls
+        if decoded_message.tool_calls:
+            openai_tool_calls = [
+                OpenAIChatCompletionToolCall(
+                    # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+                    id=f"call_{uuid.uuid4().hex[:24]}",
+                    type="function",
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=str(tc.tool_name),
+                        arguments=tc.arguments,
+                    ),
+                )
+                for tc in decoded_message.tool_calls
+            ]
+
+            # Yield chunk with tool_calls
+            chunk = OpenAIChatCompletionChunk(
+                id=response_id,
+                object="chat.completion.chunk",
+                created=created,
+                model=params.model,
+                choices=[
+                    OpenAIChunkChoice(
+                        index=0,
+                        delta=OpenAIChoiceDelta(
+                            role="assistant",
+                            tool_calls=openai_tool_calls,
+                        ),
+                        finish_reason="",
+                        logprobs=None,
+                    )
+                ],
+            )
+            yield chunk
+
+            finish_reason = "tool_calls"
+        else:
+            finish_reason = "stop"
+
+        # Yield final chunk with finish_reason
+        final_chunk = OpenAIChatCompletionChunk(
+            id=response_id,
+            object="chat.completion.chunk",
+            created=created,
+            model=params.model,
+            choices=[
+                OpenAIChunkChoice(
+                    index=0,
+                    delta=OpenAIChoiceDelta(),
+                    finish_reason=finish_reason,
+                    logprobs=None,
+                )
+            ],
+        )
+        yield final_chunk