Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Merge pull request #5368 from BerriAI/litellm_vertex_function_support
feat(vertex_httpx.py): support 'functions' param for gemini google ai studio + vertex ai
Commit: c503ff435e
7 changed files with 263 additions and 84 deletions
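
For context, a minimal usage sketch of what this change enables, adapted from the tests added in this diff (illustrative model name; assumes GEMINI_API_KEY is set in the environment):

import litellm

function1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

# The legacy OpenAI 'functions' param is now mapped to Gemini tools for
# gemini/* (Google AI Studio) and Vertex AI models.
response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=function1,
)
print(response.choices[0].message.function_call)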
@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
    ChatCompletionResponseMessage,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionToolParamFunctionChunk,
    ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
@@ -296,11 +297,50 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
            "stream",
            "tools",
            "tool_choice",
            "functions",
            "response_format",
            "n",
            "stop",
        ]

    def _map_function(self, value: List[dict]) -> List[Tools]:
        gtool_func_declarations = []
        googleSearchRetrieval: Optional[dict] = None

        for tool in value:
            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
                None
            )
            if "function" in tool:  # tools list
                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
                    **tool["function"]
                )
            elif "name" in tool:  # functions list
                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore

            # check if grounding
            if tool.get("googleSearchRetrieval", None) is not None:
                googleSearchRetrieval = tool["googleSearchRetrieval"]
            elif openai_function_object is not None:
                gtool_func_declaration = FunctionDeclaration(
                    name=openai_function_object["name"],
                    description=openai_function_object.get("description", ""),
                    parameters=openai_function_object.get("parameters", {}),
                )
                gtool_func_declarations.append(gtool_func_declaration)
            else:
                # assume it's a provider-specific param
                verbose_logger.warning(
                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
                )

        _tools = Tools(
            function_declarations=gtool_func_declarations,
        )
        if googleSearchRetrieval is not None:
            _tools["googleSearchRetrieval"] = googleSearchRetrieval
        return [_tools]

    def map_tool_choice_values(
        self, model: str, tool_choice: Union[str, dict]
    ) -> Optional[ToolConfig]:
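
A rough sketch of the translation `_map_function` performs: an OpenAI-style function definition becomes a Gemini `tools` entry with `function_declarations` (plain dict literals for illustration; the real code builds the Tools/FunctionDeclaration TypedDicts):

# A hypothetical OpenAI-style 'functions' entry (same schema as the tests in this diff):
openai_function = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}

# Shape of the Gemini 'tools' value that _map_function produces from it:
expected_gemini_tools = [
    {
        "function_declarations": [
            {
                "name": openai_function["name"],
                "description": openai_function["description"],
                "parameters": openai_function["parameters"],
            }
        ]
        # a "googleSearchRetrieval" key is added here only if a grounding tool was passed
    }
]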
@@ -363,26 +403,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
                if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
                    optional_params["response_mime_type"] = "application/json"
                    optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
            if param == "tools" and isinstance(value, list):
                gtool_func_declarations = []
                for tool in value:
                    _parameters = tool.get("function", {}).get("parameters", {})
                    _properties = _parameters.get("properties", {})
                    if isinstance(_properties, dict):
                        for _, _property in _properties.items():
                            if "enum" in _property and "format" not in _property:
                                _property["format"] = "enum"

                    gtool_func_declaration = FunctionDeclaration(
                        name=tool["function"]["name"],
                        description=tool["function"].get("description", ""),
                    )
                    if len(_parameters.keys()) > 0:
                        gtool_func_declaration["parameters"] = _parameters
                    gtool_func_declarations.append(gtool_func_declaration)
                optional_params["tools"] = [
                    Tools(function_declarations=gtool_func_declarations)
                ]
            if (param == "tools" or param == "functions") and isinstance(value, list):
                optional_params["tools"] = self._map_function(value=value)
                optional_params["litellm_param_is_function_call"] = (
                    True if param == "functions" else False
                )
            if param == "tool_choice" and (
                isinstance(value, str) or isinstance(value, dict)
            ):

@@ -506,6 +531,7 @@ class VertexGeminiConfig:
            "max_tokens",
            "stream",
            "tools",
            "functions",
            "tool_choice",
            "response_format",
            "n",
@@ -541,6 +567,44 @@ class VertexGeminiConfig:
                status_code=400,
            )

    def _map_function(self, value: List[dict]) -> List[Tools]:
        gtool_func_declarations = []
        googleSearchRetrieval: Optional[dict] = None

        for tool in value:
            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
                None
            )
            if "function" in tool:  # tools list
                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
                    **tool["function"]
                )
            elif "name" in tool:  # functions list
                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore

            # check if grounding
            if tool.get("googleSearchRetrieval", None) is not None:
                googleSearchRetrieval = tool["googleSearchRetrieval"]
            elif openai_function_object is not None:
                gtool_func_declaration = FunctionDeclaration(
                    name=openai_function_object["name"],
                    description=openai_function_object.get("description", ""),
                    parameters=openai_function_object.get("parameters", {}),
                )
                gtool_func_declarations.append(gtool_func_declaration)
            else:
                # assume it's a provider-specific param
                verbose_logger.warning(
                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
                )

        _tools = Tools(
            function_declarations=gtool_func_declarations,
        )
        if googleSearchRetrieval is not None:
            _tools["googleSearchRetrieval"] = googleSearchRetrieval
        return [_tools]

    def map_openai_params(
        self,
        model: str,
@@ -582,33 +646,11 @@ class VertexGeminiConfig:
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "tools" and isinstance(value, list):
                gtool_func_declarations = []
                googleSearchRetrieval: Optional[dict] = None
                provider_specific_tools: List[dict] = []
                for tool in value:
                    # check if grounding
                    try:
                        gtool_func_declaration = FunctionDeclaration(
                            name=tool["function"]["name"],
                            description=tool["function"].get("description", ""),
                            parameters=tool["function"].get("parameters", {}),
                        )
                        gtool_func_declarations.append(gtool_func_declaration)
                    except KeyError:
                        if tool.get("googleSearchRetrieval", None) is not None:
                            googleSearchRetrieval = tool["googleSearchRetrieval"]
                        else:
                            # assume it's a provider-specific param
                            verbose_logger.warning(
                                "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
                            )
                _tools = Tools(
                    function_declarations=gtool_func_declarations,
            if (param == "tools" or param == "functions") and isinstance(value, list):
                optional_params["tools"] = self._map_function(value=value)
                optional_params["litellm_param_is_function_call"] = (
                    True if param == "functions" else False
                )
                if googleSearchRetrieval is not None:
                    _tools["googleSearchRetrieval"] = googleSearchRetrieval
                optional_params["tools"] = [_tools] + provider_specific_tools
            if param == "tool_choice" and (
                isinstance(value, str) or isinstance(value, dict)
            ):
@@ -780,6 +822,7 @@ class VertexLLM(BaseLLM):
        model_response: ModelResponse,
        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        litellm_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: List,

@@ -796,7 +839,6 @@ class VertexLLM(BaseLLM):
        )

        print_verbose(f"raw model_response: {response.text}")

        ## RESPONSE OBJECT
        try:
            completion_response = GenerateContentResponseBody(**response.json())  # type: ignore

@@ -904,6 +946,7 @@ class VertexLLM(BaseLLM):
            chat_completion_message = {"role": "assistant"}
            content_str = ""
            tools: List[ChatCompletionToolCallChunk] = []
            functions: Optional[ChatCompletionToolCallFunctionChunk] = None
            for idx, candidate in enumerate(completion_response["candidates"]):
                if "content" not in candidate:
                    continue
@@ -926,18 +969,25 @@ class VertexLLM(BaseLLM):
                        candidate["content"]["parts"][0]["functionCall"]["args"]
                    ),
                )
                _tool_response_chunk = ChatCompletionToolCallChunk(
                    id=f"call_{str(uuid.uuid4())}",
                    type="function",
                    function=_function_chunk,
                    index=candidate.get("index", idx),
                )
                tools.append(_tool_response_chunk)
                if litellm_params.get("litellm_param_is_function_call") is True:
                    functions = _function_chunk
                else:
                    _tool_response_chunk = ChatCompletionToolCallChunk(
                        id=f"call_{str(uuid.uuid4())}",
                        type="function",
                        function=_function_chunk,
                        index=candidate.get("index", idx),
                    )
                    tools.append(_tool_response_chunk)

            chat_completion_message["content"] = (
                content_str if len(content_str) > 0 else None
            )
            chat_completion_message["tool_calls"] = tools
            if len(tools) > 0:
                chat_completion_message["tool_calls"] = tools

            if functions is not None:
                chat_completion_message["function_call"] = functions

            choice = litellm.Choices(
                finish_reason=candidate.get("finishReason", "stop"),
@@ -1235,7 +1285,7 @@ class VertexLLM(BaseLLM):
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        litellm_params: dict,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,

@@ -1269,6 +1319,7 @@ class VertexLLM(BaseLLM):
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

@@ -1290,7 +1341,7 @@ class VertexLLM(BaseLLM):
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        gemini_api_key: Optional[str],
        litellm_params=None,
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,

@@ -1302,7 +1353,6 @@ class VertexLLM(BaseLLM):
            optional_params=optional_params
        )

        print_verbose("Incoming Vertex Args - {}".format(locals()))
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,

@@ -1314,7 +1364,6 @@ class VertexLLM(BaseLLM):
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
        )
        print_verbose("Updated URL - {}".format(url))

        ## TRANSFORMATION ##
        try:
@@ -1358,6 +1407,18 @@ class VertexLLM(BaseLLM):
            )
            optional_params.pop("response_schema")

        # Check for any 'litellm_param_*' set during optional param mapping

        remove_keys = []
        for k, v in optional_params.items():
            if k.startswith("litellm_param_"):
                litellm_params.update({k: v})
                remove_keys.append(k)

        optional_params = {
            k: v for k, v in optional_params.items() if k not in remove_keys
        }

        try:
            content = _gemini_convert_messages_with_history(messages=messages)
            tools: Optional[Tools] = optional_params.pop("tools", None)

@@ -1491,6 +1552,7 @@ class VertexLLM(BaseLLM):
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            litellm_params=litellm_params,
            api_key="",
            data=data,  # type: ignore
            messages=messages,
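
The `litellm_param_*` handoff above keeps internal flags out of the provider payload; a standalone sketch of the same filtering, assuming a flag set during param mapping:

optional_params = {
    "tools": [{"function_declarations": []}],
    "litellm_param_is_function_call": True,  # internal flag, must not reach the provider
}
litellm_params: dict = {}

# Move any 'litellm_param_*' keys out of the request payload and into litellm_params.
remove_keys = [k for k in optional_params if k.startswith("litellm_param_")]
for k in remove_keys:
    litellm_params[k] = optional_params.pop(k)

assert "litellm_param_is_function_call" not in optional_params
assert litellm_params["litellm_param_is_function_call"] is True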
@@ -2009,7 +2009,7 @@ def completion(
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=new_params,
                litellm_params=litellm_params,
                litellm_params=litellm_params,  # type: ignore
                logger_fn=logger_fn,
                encoding=encoding,
                vertex_location=vertex_ai_location,

@@ -2096,7 +2096,7 @@ def completion(
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=new_params,
                litellm_params=litellm_params,
                litellm_params=litellm_params,  # type: ignore
                logger_fn=logger_fn,
                encoding=encoding,
                vertex_location=vertex_ai_location,
@@ -2691,8 +2691,61 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider()


@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
def test_completion_anyscale_with_functions():
def gemini_mock_post(*args, **kwargs):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json = MagicMock(
        return_value={
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "functionCall": {
                                    "name": "get_current_weather",
                                    "args": {"location": "Boston, MA"},
                                }
                            }
                        ],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                    "index": 0,
                    "safetyRatings": [
                        {
                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                            "probability": "NEGLIGIBLE",
                        },
                        {
                            "category": "HARM_CATEGORY_HARASSMENT",
                            "probability": "NEGLIGIBLE",
                        },
                        {
                            "category": "HARM_CATEGORY_HATE_SPEECH",
                            "probability": "NEGLIGIBLE",
                        },
                        {
                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                            "probability": "NEGLIGIBLE",
                        },
                    ],
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 86,
                "candidatesTokenCount": 19,
                "totalTokenCount": 105,
            },
        }
    )

    return mock_response


@pytest.mark.asyncio
async def test_completion_functions_param():
    litellm.set_verbose = True
    function1 = [
        {
            "name": "get_current_weather",
@@ -2711,18 +2764,33 @@ def test_completion_anyscale_with_functions():
        }
    ]
    try:
        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
        response = completion(
            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
            messages=messages,
            functions=function1,
        )
        # Add any assertions here to check the response
        print(response)
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

        cost = litellm.completion_cost(completion_response=response)
        print("cost to make anyscale completion=", cost)
        assert cost > 0.0
    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

    client = AsyncHTTPHandler(concurrent_limit=1)

    with patch.object(client, "post", side_effect=gemini_mock_post) as mock_client:
        response: litellm.ModelResponse = await litellm.acompletion(
            model="gemini/gemini-1.5-pro",
            messages=messages,
            functions=function1,
            client=client,
        )
        print(response)
        # Add any assertions here to check the response
        mock_client.assert_called()
        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        assert "tools" in mock_client.call_args.kwargs["json"]
        assert (
            "litellm_param_is_function_call"
            not in mock_client.call_args.kwargs["json"]
        )
        assert (
            "litellm_param_is_function_call"
            not in mock_client.call_args.kwargs["json"]["generationConfig"]
        )
        assert response.choices[0].message.function_call is not None
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
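
Given the mocked functionCall above, the OpenAI-style message that the test checks should look roughly like the sketch below (illustrative dict, not the actual response object; the Gemini args are serialized to a JSON string for the OpenAI format):

expected_message = {
    "role": "assistant",
    "content": None,
    # present because the request used 'functions' rather than 'tools'
    "function_call": {
        "name": "get_current_weather",
        "arguments": '{"location": "Boston, MA"}',
    },
}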
@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
    try:
        litellm.set_verbose = True
        print("Streaming gemini response")
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
        function1 = [
            {
                "role": "user",
                "content": "Who was Alexander?",
            },
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            }
        ]
        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
        print("testing gemini streaming")
        complete_response = ""
        # Add any assertions here to check the response
        non_empty_chunks = 0

        chunks = []
        if sync_mode:
            response = completion(
                model="gemini/gemini-1.5-flash",
                messages=messages,
                stream=True,
                functions=function1,
            )

            for idx, chunk in enumerate(response):
                print(chunk)
                chunks.append(chunk)
                # print(chunk.choices[0].delta)
                chunk, finished = streaming_format_tests(idx, chunk)
                if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
                model="gemini/gemini-1.5-flash",
                messages=messages,
                stream=True,
                functions=function1,
            )

            idx = 0
            async for chunk in response:
                print(chunk)
                chunks.append(chunk)
                # print(chunk.choices[0].delta)
                chunk, finished = streaming_format_tests(idx, chunk)
                if finished:

@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
                complete_response += chunk
                idx += 1

        if complete_response.strip() == "":
            raise Exception("Empty response received")
        # if complete_response.strip() == "":
        # raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
        assert non_empty_chunks > 1

        complete_response = litellm.stream_chunk_builder(
            chunks=chunks, messages=messages
        )

        assert complete_response.choices[0].message.function_call is not None

        # assert non_empty_chunks > 1
    except litellm.InternalServerError as e:
        pass
    except litellm.RateLimitError as e:
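
The streaming test relies on litellm.stream_chunk_builder to reassemble the deltas; a minimal sketch of that pattern outside the test harness (assumes GEMINI_API_KEY is set; the compact function schema below is illustrative):

import litellm
from litellm import completion

function1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

chunks = []
for chunk in completion(
    model="gemini/gemini-1.5-flash",
    messages=messages,
    functions=function1,
    stream=True,
):
    chunks.append(chunk)

# stream_chunk_builder stitches the streamed deltas back into a full ModelResponse;
# because the request used 'functions', the rebuilt message carries function_call.
rebuilt = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
print(rebuilt.choices[0].message.function_call)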
@@ -449,6 +449,7 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
    content: Optional[str]
    tool_calls: List[ChatCompletionToolCallChunk]
    role: Literal["assistant"]
    function_call: ChatCompletionToolCallFunctionChunk


class ChatCompletionUsageBlock(TypedDict):
@@ -90,7 +90,7 @@ class Schema(TypedDict, total=False):
class FunctionDeclaration(TypedDict, total=False):
    name: Required[str]
    description: str
    parameters: Schema
    parameters: Union[Schema, dict]
    response: Schema

@@ -8779,6 +8779,7 @@ class CustomStreamWrapper:
        self.chunks: List = (
            []
        )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
        self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)

    def __iter__(self):
        return self

@@ -8786,6 +8787,19 @@ class CustomStreamWrapper:
    def __aiter__(self):
        return self

    def check_is_function_call(self, logging_obj) -> bool:
        if hasattr(logging_obj, "optional_params") and isinstance(
            logging_obj.optional_params, dict
        ):
            if (
                "litellm_param_is_function_call" in logging_obj.optional_params
                and logging_obj.optional_params["litellm_param_is_function_call"]
                is True
            ):
                return True

        return False

    def process_chunk(self, chunk: str):
        """
        NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
@@ -10283,6 +10297,12 @@ class CustomStreamWrapper:

                ## CHECK FOR TOOL USE
                if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
                    if self.is_function_call is True:  # user passed in 'functions' param
                        completion_obj["function_call"] = completion_obj["tool_calls"][0][
                            "function"
                        ]
                        completion_obj["tool_calls"] = None

                    self.tool_call = True

                ## RETURN ARG

@@ -10294,8 +10314,13 @@ class CustomStreamWrapper:
                    )
                    or (
                        "tool_calls" in completion_obj
                        and completion_obj["tool_calls"] is not None
                        and len(completion_obj["tool_calls"]) > 0
                    )
                    or (
                        "function_call" in completion_obj
                        and completion_obj["function_call"] is not None
                    )
                ):  # cannot set content of an OpenAI Object to be an empty string
                    self.safety_checker()
                    hold, model_response_str = self.check_special_tokens(

@@ -10355,6 +10380,7 @@ class CustomStreamWrapper:
                    if self.sent_first_chunk is False:
                        completion_obj["role"] = "assistant"
                        self.sent_first_chunk = True

                    model_response.choices[0].delta = Delta(**completion_obj)
                    if completion_obj.get("index") is not None:
                        model_response.choices[0].index = completion_obj.get(
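
Putting the CustomStreamWrapper changes together, a standalone sketch of the chunk-level conversion that `check_is_function_call` enables (illustrative dicts; the real code operates on the wrapper's internal completion_obj):

completion_obj = {
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Boston, MA"}',
            },
            "index": 0,
        }
    ],
}

is_function_call = True  # True when the user passed the legacy 'functions' param
if is_function_call and completion_obj.get("tool_calls"):
    # surface the first tool call as an OpenAI-style function_call delta
    completion_obj["function_call"] = completion_obj["tool_calls"][0]["function"]
    completion_obj["tool_calls"] = None

assert completion_obj["function_call"]["name"] == "get_current_weather"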