mirror of https://github.com/meta-llama/llama-stack.git
make TGI work well

commit 021dd0d35d (parent e58c7f6c37)
9 changed files with 617 additions and 326 deletions
@@ -6,55 +6,7 @@
 import json
 import logging
 import warnings
-from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
-
-from openai import AsyncStream
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
-)
-from openai.types.chat import (
-    ChatCompletionChunk as OpenAIChatCompletionChunk,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
-)
-from openai.types.chat import (
-    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
-)
-from openai.types.chat import (
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-)
-from openai.types.chat import (
-    ChatCompletionMessageToolCall,
-)
-from openai.types.chat import (
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
-)
-from openai.types.chat import (
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-)
-from openai.types.chat import (
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-)
-from openai.types.chat import (
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-)
-from openai.types.chat.chat_completion import (
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call_param import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel
+from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -71,11 +23,14 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
+    CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    JsonSchemaResponseFormat,
     Message,
     SystemMessage,
     TokenLogProbs,
+    ToolChoice,
     ToolResponseMessage,
     UserMessage,
 )
@@ -94,6 +49,32 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )
 
+from openai import AsyncStream, Stream
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
+    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    Function as OpenAIFunction,
+)
+from pydantic import BaseModel
+
 logger = logging.getLogger(__name__)
 
 
@@ -188,12 +169,16 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
+            for token, token_lp in zip(
+                logprobs.tokens, logprobs.token_logprobs, strict=False
+            )
         ]
     return None
 
 
-def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
+def convert_openai_completion_logprobs_stream(
+    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
+):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
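Note: the comprehension above pairs each sampled token with its log-probability, one single-entry dict per token. A toy illustration with made-up token and logprob lists standing in for a real OpenAICompatLogprobs object:

```python
# Toy data standing in for logprobs.tokens / logprobs.token_logprobs.
tokens = ["Hello", ",", " world"]
token_logprobs = [-0.1, -0.5, -0.2]

# Mirrors the comprehension in the hunk above: one {token: logprob} dict per token.
paired = [
    {token: token_lp}
    for token, token_lp in zip(tokens, token_logprobs, strict=False)
]
print(paired)  # [{'Hello': -0.1}, {',': -0.5}, {' world': -0.2}]
```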
@@ -238,7 +223,9 @@ def process_chat_completion_response(
     if not choice.message or not choice.message.tool_calls:
         raise ValueError("Tool calls are not present in the response")
 
-    tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+    tool_calls = [
+        convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
+    ]
     if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
         # If we couldn't parse a tool call, jsonify the tool calls and return them
         return ChatCompletionResponse(
@@ -262,7 +249,9 @@ def process_chat_completion_response(
 
     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
+    raw_message = decode_assistant_message(
+        text_from_choice(choice), get_stop_reason(choice.finish_reason)
+    )
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -463,13 +452,17 @@ async def process_chat_completion_stream_response(
     )
 
 
-async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
+async def convert_message_to_openai_dict(
+    message: Message, download: bool = False
+) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(content, download=download),
+                    "url": await convert_image_content_to_url(
+                        content, download=download
+                    ),
                 },
             }
         else:
@@ -548,7 +541,9 @@ async def convert_message_to_openai_dict_new(
         elif isinstance(content_, ImageContentItem):
             return OpenAIChatCompletionContentPartImageParam(
                 type="image_url",
-                image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
+                image_url=OpenAIImageURL(
+                    url=await convert_image_content_to_url(content_)
+                ),
             )
         elif isinstance(content_, list):
             return [await impl(item) for item in content_]
@@ -574,14 +569,32 @@ async def convert_message_to_openai_dict_new(
             role="assistant",
             content=await _convert_message_content(message.content),
             tool_calls=[
-                OpenAIChatCompletionMessageToolCall(
-                    id=tool.call_id,
-                    function=OpenAIFunction(
-                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
-                        arguments=json.dumps(tool.arguments),
-                    ),
-                    type="function",
-                )
+                # OpenAIChatCompletionMessageToolCall(
+                #     id=tool.call_id,
+                #     function=OpenAIFunction(
+                #         name=(
+                #             tool.tool_name
+                #             if not isinstance(tool.tool_name, BuiltinTool)
+                #             else tool.tool_name.value
+                #         ),
+                #         arguments=json.dumps(tool.arguments),
+                #     ),
+                #     type="function",
+                # )
+                # using a dict instead of OpenAIChatCompletionMessageToolCall object
+                # as it fails to get json encoded
+                {
+                    "id": tool.call_id,
+                    "function": {
+                        "name": (
+                            tool.tool_name
+                            if not isinstance(tool.tool_name, BuiltinTool)
+                            else tool.tool_name.value
+                        ),
+                        "arguments": json.dumps(tool.arguments),
+                    },
+                    "type": "function",
+                }
                 for tool in message.tool_calls
             ]
             or None,
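Note: the typed construction is kept as a comment because, as the new comment says, the tool-call object "fails to get json encoded". A minimal sketch of that failure mode, using hypothetical pydantic models in place of the real openai types:

```python
import json

from pydantic import BaseModel


# Hypothetical stand-ins for the openai tool-call types; not the real classes.
class Function(BaseModel):
    name: str
    arguments: str


class ToolCallModel(BaseModel):
    id: str
    type: str
    function: Function


call = ToolCallModel(
    id="call_1",
    type="function",
    function=Function(name="get_weather", arguments='{"city": "SF"}'),
)

try:
    # The stdlib encoder does not know how to serialize pydantic models...
    json.dumps({"tool_calls": [call]})
except TypeError as e:
    print(f"json.dumps failed: {e}")

# ...while a plain dict of the same shape encodes without trouble.
print(json.dumps({"tool_calls": [call.model_dump()]}))
```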
@@ -604,7 +617,7 @@ async def convert_message_to_openai_dict_new(
 
 
 def convert_tool_call(
-    tool_call: ChatCompletionMessageToolCall,
+    tool_call: OpenAIChatCompletionMessageToolCall,
 ) -> Union[ToolCall, UnparseableToolCall]:
     """
     Convert a ChatCompletionMessageToolCall tool call to either a
@@ -696,7 +709,11 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     properties = parameters["properties"]
     required = []
     for param_name, param in tool.parameters.items():
-        properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
+        properties[param_name] = {
+            "type": PYTHON_TYPE_TO_LITELLM_TYPE.get(
+                param.param_type, param.param_type
+            )
+        }
         if param.description:
             properties[param_name].update(description=param.description)
         if param.default:
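Note: PYTHON_TYPE_TO_LITELLM_TYPE translates Python type names into JSON-schema type names, falling back to the original string for anything unmapped. A small sketch of the lookup, assuming the usual scalar mappings (the authoritative table is defined elsewhere in this module):

```python
# Assumed shape of the mapping; the real table lives in this module.
PYTHON_TYPE_TO_LITELLM_TYPE = {
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "str": "string",
}


def json_schema_type(param_type: str) -> str:
    # Same lookup as the hunk above: unmapped names pass through unchanged.
    return PYTHON_TYPE_TO_LITELLM_TYPE.get(param_type, param_type)


print(json_schema_type("int"))      # integer
print(json_schema_type("integer"))  # integer (already a JSON-schema name)
```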
@@ -762,15 +779,30 @@ def _convert_openai_tool_calls(
     if not tool_calls:
         return []  # CompletionMessage tool_calls is not optional
 
-    return [
-        ToolCall(
-            call_id=call.id,
-            tool_name=call.function.name,
-            arguments=json.loads(call.function.arguments),
-            arguments_json=call.function.arguments,
-        )
-        for call in tool_calls
-    ]
+    ls_tool_calls = []
+    for call in tool_calls:
+        args = call.function.arguments
+        # TGI is sending a dict instead of a json string
+        # While OpenAI spec expects a json string
+        if isinstance(args, str):
+            arguments = json.loads(args)
+            arguments_json = args
+        elif isinstance(args, dict):
+            arguments = args
+            arguments_json = json.dumps(args)
+        else:
+            raise ValueError(f"Unsupported arguments type: {type(args)}")
+
+        ls_tool_calls.append(
+            ToolCall(
+                call_id=call.id,
+                tool_name=call.function.name,
+                arguments=arguments,
+                arguments_json=arguments_json,
+            )
+        )
+
+    return ls_tool_calls
 
 
 def _convert_openai_logprobs(
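Note: this is the heart of the TGI fix in _convert_openai_tool_calls. The OpenAI spec delivers function arguments as a JSON string, but TGI can deliver them as an already-parsed dict. A self-contained sketch of the normalization, with normalize_arguments as a hypothetical helper name:

```python
import json
from typing import Any, Dict, Tuple, Union


def normalize_arguments(args: Union[str, Dict[str, Any]]) -> Tuple[Dict[str, Any], str]:
    """Return (parsed_dict, json_string) for either form the server might send."""
    if isinstance(args, str):
        # OpenAI-compliant servers: arguments is a JSON string.
        return json.loads(args), args
    if isinstance(args, dict):
        # TGI: arguments may already be a dict.
        return args, json.dumps(args)
    raise ValueError(f"Unsupported arguments type: {type(args)}")


assert normalize_arguments('{"city": "SF"}') == ({"city": "SF"}, '{"city": "SF"}')
assert normalize_arguments({"city": "SF"}) == ({"city": "SF"}, '{"city": "SF"}')
```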
@@ -802,7 +834,11 @@ def _convert_openai_logprobs(
         return None
 
     return [
-        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
+        TokenLogProbs(
+            logprobs_by_token={
+                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
+            }
+        )
         for content in logprobs.content
     ]
 
@@ -840,14 +876,17 @@ def convert_openai_chat_completion_choice(
         end_of_message = "end_of_message"
         out_of_tokens = "out_of_tokens"
     """
-    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
-    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
-        "error in server response: finish_reason not found"
-    )
+    assert (
+        hasattr(choice, "message") and choice.message
+    ), "error in server response: message not found"
+    assert (
+        hasattr(choice, "finish_reason") and choice.finish_reason
+    ), "error in server response: finish_reason not found"
 
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content or "",  # CompletionMessage content is not optional
+            content=choice.message.content
+            or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -856,13 +895,24 @@ def convert_openai_chat_completion_choice(
 
 
 async def convert_openai_chat_completion_stream(
-    stream: AsyncStream[OpenAIChatCompletionChunk],
+    stream: Union[
+        AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk]
+    ],
     enable_incremental_tool_calls: bool,
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     """
     Convert a stream of OpenAI chat completion chunks into a stream
     of ChatCompletionResponseStreamChunk.
     """
+
+    async def yield_from_stream(stream):
+        if isinstance(stream, AsyncGenerator):
+            async for chunk in stream:
+                yield chunk
+        elif isinstance(stream, Generator):
+            for chunk in stream:
+                yield chunk
+
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
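Note: the new yield_from_stream helper lets the converter accept either an async or a sync source, which matters here because a TGI client path can hand back a synchronous Stream. A runnable sketch of the same bridging idea, with toy generators standing in for OpenAI streams:

```python
import asyncio
from typing import AsyncGenerator, Generator


async def yield_from_stream(stream):
    # Async sources are drained with `async for`, sync ones with a plain loop.
    if isinstance(stream, AsyncGenerator):
        async for chunk in stream:
            yield chunk
    elif isinstance(stream, Generator):
        for chunk in stream:
            yield chunk


async def main():
    def sync_chunks() -> Generator[int, None, None]:
        yield from (1, 2, 3)

    async def async_chunks() -> AsyncGenerator[int, None]:
        for i in (4, 5, 6):
            yield i

    print([c async for c in yield_from_stream(sync_chunks())])   # [1, 2, 3]
    print([c async for c in yield_from_stream(async_chunks())])  # [4, 5, 6]


asyncio.run(main())
```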
@@ -874,7 +924,7 @@ async def convert_openai_chat_completion_stream(
     stop_reason = None
     tool_call_idx_to_buffer = {}
 
-    async for chunk in stream:
+    async for chunk in yield_from_stream(stream):
         choice = chunk.choices[0]  # assuming only one choice per chunk
 
         # we assume there's only one finish_reason in the stream
@@ -916,12 +966,60 @@ async def convert_openai_chat_completion_stream(
                     )
                 )
             else:
-                for tool_call in choice.delta.tool_calls:
-                    idx = tool_call.index if hasattr(tool_call, "index") else 0
-
-                    if idx not in tool_call_idx_to_buffer:
-                        tool_call_idx_to_buffer[idx] = {
-                            "call_id": tool_call.id,
-                            "name": None,
-                            "arguments": "",
-                            "content": "",
+                if isinstance(choice.delta.tool_calls, list):
+                    tool_calls = choice.delta.tool_calls
+                    for tool_call in tool_calls:
+                        idx = tool_call.index if hasattr(tool_call, "index") else 0
+
+                        if idx not in tool_call_idx_to_buffer:
+                            tool_call_idx_to_buffer[idx] = {
+                                "call_id": tool_call.id,
+                                "name": None,
+                                "arguments": "",
+                                "content": "",
+                            }
+
+                        buffer = tool_call_idx_to_buffer[idx]
+
+                        if tool_call.function:
+                            if tool_call.function.name:
+                                buffer["name"] = tool_call.function.name
+                                delta = f"{buffer['name']}("
+                                buffer["content"] += delta
+
+                            if tool_call.function.arguments:
+                                delta = tool_call.function.arguments
+                                buffer["arguments"] += delta
+                                buffer["content"] += delta
+
+                                yield ChatCompletionResponseStreamChunk(
+                                    event=ChatCompletionResponseEvent(
+                                        event_type=event_type,
+                                        delta=ToolCallDelta(
+                                            tool_call=delta,
+                                            parse_status=ToolCallParseStatus.in_progress,
+                                        ),
+                                        logprobs=_convert_openai_logprobs(logprobs),
+                                    )
+                                )
+                # TGI streams a non-openai compat response
+                elif isinstance(choice.delta.tool_calls, dict):
+                    # tool_calls is a dict of the format
+                    # {
+                    #     'index': 0,
+                    #     'id': '',
+                    #     'type': 'function',
+                    #     'function': {
+                    #         'name': None,
+                    #         'arguments': '{"'
+                    #     }
+                    # }
+                    tool_call = choice.delta.tool_calls
+                    idx = tool_call["index"] if "index" in tool_call else 0
+
+                    if idx not in tool_call_idx_to_buffer:
+                        tool_call_idx_to_buffer[idx] = {
+                            "call_id": tool_call["id"],
+                            "name": None,
+                            "arguments": "",
+                            "content": "",
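Note: the list branch handles OpenAI-compliant deltas (a list of tool-call objects); the dict branch handles TGI's non-compliant shape, where choice.delta.tool_calls is a single dict like the one in the comment. A toy sketch of the same per-index buffering, using plain dicts for both shapes (unlike the real code, where list entries are typed objects):

```python
tool_call_idx_to_buffer = {}


def feed(delta):
    """Accept either an OpenAI-style list of deltas or a TGI-style single dict."""
    deltas = delta if isinstance(delta, list) else [delta]
    for tc in deltas:
        idx = tc.get("index", 0)
        buf = tool_call_idx_to_buffer.setdefault(
            idx, {"call_id": tc.get("id"), "name": None, "arguments": "", "content": ""}
        )
        fn = tc.get("function") or {}
        if fn.get("name"):
            buf["name"] = fn["name"]
            buf["content"] += f"{buf['name']}("
        if fn.get("arguments"):
            buf["arguments"] += fn["arguments"]
            buf["content"] += fn["arguments"]


# TGI-style dict deltas arriving one fragment at a time:
feed({"index": 0, "id": "call_0", "type": "function", "function": {"name": None, "arguments": '{"'}})
feed({"index": 0, "id": "call_0", "type": "function", "function": {"name": None, "arguments": 'city": "SF"}'}})
print(tool_call_idx_to_buffer[0]["arguments"])  # {"city": "SF"}
```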
@@ -929,14 +1027,15 @@ async def convert_openai_chat_completion_stream(
 
                     buffer = tool_call_idx_to_buffer[idx]
 
-                    if tool_call.function:
-                        if tool_call.function.name:
-                            buffer["name"] = tool_call.function.name
+                    if "function" in tool_call:
+                        function = tool_call["function"]
+                        if function["name"]:
+                            buffer["name"] = function["name"]
                             delta = f"{buffer['name']}("
                             buffer["content"] += delta
 
-                        if tool_call.function.arguments:
-                            delta = tool_call.function.arguments
+                        if function["arguments"]:
+                            delta = function["arguments"]
                             buffer["arguments"] += delta
                             buffer["content"] += delta
 
@@ -994,7 +1093,6 @@ async def convert_openai_chat_completion_stream(
                     )
                 )
             except json.JSONDecodeError as e:
-                print(f"Failed to parse arguments: {e}")
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -1005,6 +1103,51 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                     )
                 )
+        else:
+            # name is None but we have the arguments contain the entire function call
+            # example response where arguments is -->
+            # '{"function": {"_name": "get_weather", "location": "San Francisco, CA"}}<|eot_id|>'
+            # - parse the arguments
+            # - build try to build ToolCall and return it or return the content as is
+
+            if buffer["arguments"]:
+                arguments = buffer["arguments"]
+                # remove the eot_id and eom_id from the arguments
+                if arguments.endswith("<|eom_id|>"):
+                    arguments = arguments[: -len("<|eom_id|>")]
+                if arguments.endswith("<|eot_id|>"):
+                    arguments = arguments[: -len("<|eot_id|>")]
+
+                arguments = json.loads(arguments)
+                try:
+                    tool_name = arguments["function"].pop("_name", None)
+                    parsed_tool_call = ToolCall(
+                        call_id=buffer["call_id"],
+                        tool_name=tool_name,
+                        arguments=arguments["function"],
+                        arguments_json=json.dumps(arguments["function"]),
+                    )
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=parsed_tool_call,
+                                parse_status=ToolCallParseStatus.succeeded,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )
+                except (KeyError, json.JSONDecodeError) as e:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=buffer["content"],
+                                parse_status=ToolCallParseStatus.failed,
+                            ),
+                            stop_reason=stop_reason,
+                        )
+                    )
 
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
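Note: when TGI never supplies a function name, the whole call ends up in buffer["arguments"] as a JSON blob carrying a "_name" key, possibly followed by a Llama stop token. A minimal sketch of the recovery path above, with parse_tgi_function_call as a hypothetical helper name:

```python
import json


def parse_tgi_function_call(arguments: str):
    # Strip Llama stop tokens that may trail the JSON payload.
    for stop in ("<|eom_id|>", "<|eot_id|>"):
        if arguments.endswith(stop):
            arguments = arguments[: -len(stop)]
    payload = json.loads(arguments)
    function = payload["function"]
    # TGI encodes the tool name inside the arguments under "_name".
    name = function.pop("_name", None)
    return name, function


raw = '{"function": {"_name": "get_weather", "location": "San Francisco, CA"}}<|eot_id|>'
print(parse_tgi_function_call(raw))
# ('get_weather', {'location': 'San Francisco, CA'})
```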
@@ -1013,3 +1156,113 @@ async def convert_openai_chat_completion_stream(
             stop_reason=stop_reason,
         )
     )
+
+
+async def convert_completion_request_to_openai_params(
+    request: CompletionRequest,
+) -> dict:
+    """
+    Convert a CompletionRequest to an OpenAI API-compatible dictionary.
+    """
+    input_dict = {}
+    if request.logprobs:
+        input_dict["logprobs"] = request.logprobs.top_k
+
+    return {
+        "model": request.model,
+        "prompt": request.content,
+        **input_dict,
+        "stream": request.stream,
+        **get_sampling_options(request.sampling_params),
+        "n": 1,
+    }
+
+
+async def convert_chat_completion_request_to_openai_params(
+    request: ChatCompletionRequest,
+) -> dict:
+    """
+    Convert a ChatCompletionRequest to an OpenAI chat completion request.
+    """
+    input_dict = {}
+
+    input_dict["messages"] = [
+        await convert_message_to_openai_dict_new(m) for m in request.messages
+    ]
+    if fmt := request.response_format:
+        if not isinstance(fmt, JsonSchemaResponseFormat):
+            raise ValueError(
+                f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
+            )
+
+        fmt = fmt.json_schema
+        name = fmt["title"]
+        del fmt["title"]
+        fmt["additionalProperties"] = False
+
+        # Apply additionalProperties: False recursively to all objects
+        fmt = _add_additional_properties_recursive(fmt)
+
+        from rich.pretty import pprint
+
+        pprint(fmt)
+
+        input_dict["response_format"] = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": name,
+                "schema": fmt,
+                "strict": True,
+            },
+        }
+
+    if request.tools:
+        input_dict["tools"] = [
+            convert_tooldef_to_openai_tool(tool) for tool in request.tools
+        ]
+        if request.tool_config.tool_choice:
+            input_dict["tool_choice"] = (
+                request.tool_config.tool_choice.value
+                if isinstance(request.tool_config.tool_choice, ToolChoice)
+                else request.tool_config.tool_choice
+            )
+
+    return {
+        "model": request.model,
+        **input_dict,
+        "stream": request.stream,
+        **get_sampling_options(request.sampling_params),
+        "n": 1,
+    }
+
+
+def _add_additional_properties_recursive(schema):
+    """
+    Recursively add `additionalProperties: False` to all object schemas
+    """
+    if isinstance(schema, dict):
+        if schema.get("type") == "object":
+            schema["additionalProperties"] = False
+
+            # Add required field with all property keys if properties exist
+            if "properties" in schema and schema["properties"]:
+                schema["required"] = list(schema["properties"].keys())
+
+        if "properties" in schema:
+            for prop_schema in schema["properties"].values():
+                _add_additional_properties_recursive(prop_schema)
+
+        for key in ["anyOf", "allOf", "oneOf"]:
+            if key in schema:
+                for sub_schema in schema[key]:
+                    _add_additional_properties_recursive(sub_schema)
+
+        if "not" in schema:
+            _add_additional_properties_recursive(schema["not"])
+
+        # Handle $defs/$ref
+        if "$defs" in schema:
+            for def_schema in schema["$defs"].values():
+                _add_additional_properties_recursive(def_schema)
+
+    return schema
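Note: a usage sketch for _add_additional_properties_recursive as committed above (the function must be in scope for this to run). Every nested object schema gains additionalProperties: False plus a required list covering all of its properties, which is what OpenAI-style strict json_schema response formats expect:

```python
# Example schema with one nested object, as a pydantic model might emit it.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "address": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}

schema = _add_additional_properties_recursive(schema)

# Top level and nested objects are both tightened.
assert schema["additionalProperties"] is False
assert schema["required"] == ["name", "address"]
assert schema["properties"]["address"]["additionalProperties"] is False
assert schema["properties"]["address"]["required"] == ["city"]
```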