forked from phoenix/litellm-mirror
fix(stream_chunk_builder): adding support for tool calling in completion counting
This commit is contained in:
parent 40d9e8ab23
commit 4cdd930fa2
3 changed files with 109 additions and 10 deletions
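For orientation, a minimal sketch of the flow this commit targets, assuming a configured OpenAI key; the model, messages, and weather tool below are placeholders, while litellm.completion, stream_chunk_builder, and the usage fields are the pieces this diff touches. When the streamed output is a tool call, the text that should drive completion_tokens lives in each chunk's delta.tool_calls[0].function.arguments rather than delta.content:

import litellm
from litellm import stream_chunk_builder

# Placeholder request: any model/tool pair that triggers a streamed tool call works.
messages = [{"role": "user", "content": "What is the weather in Boston?"}]
tools = [{"type": "function", "function": {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {"type": "object",
                   "properties": {"location": {"type": "string"}},
                   "required": ["location"]}}}]

chunks = list(litellm.completion(model="gpt-3.5-turbo", messages=messages,
                                 tools=tools, stream=True))

# Rebuild the full response from the streamed chunks; with this commit the
# completion token count is derived from the joined tool-call arguments.
rebuilt = stream_chunk_builder(chunks, messages=messages)
print(rebuilt["usage"])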
@@ -2109,7 +2109,36 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
     content_list = []
     combined_content = ""

-    if "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
+    if "tool_calls" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None:
+        argument_list = []
+        delta = chunks[0]["choices"][0]["delta"]
+        message = response["choices"][0]["message"]
+        message["tool_calls"] = []
+        id = None
+        name = None
+        type = None
+        for chunk in chunks:
+            choices = chunk["choices"]
+            for choice in choices:
+                delta = choice.get("delta", {})
+                tool_calls = delta.get("tool_calls", "")
+                # Check if a tool call is present
+                if tool_calls and tool_calls[0].function is not None:
+                    if tool_calls[0].id:
+                        id = tool_calls[0].id
+                    if tool_calls[0].function.arguments:
+                        # Now, tool_calls is expected to be a dictionary
+                        arguments = tool_calls[0].function.arguments
+                        argument_list.append(arguments)
+                    if tool_calls[0].function.name:
+                        name = tool_calls[0].function.name
+                    if tool_calls[0].type:
+                        type = tool_calls[0].type
+
+        combined_arguments = "".join(argument_list)
+        response["choices"][0]["message"]["content"] = None
+        response["choices"][0]["message"]["tool_calls"] = [{"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}]
+    elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None:
         argument_list = []
         delta = chunks[0]["choices"][0]["delta"]
         function_call = delta.get("function_call", "")
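To see why the loop above keeps the last non-empty id/name/type and joins the argument fragments, here is a standalone sketch with SimpleNamespace stand-ins for the streamed deltas (the shapes are illustrative, not litellm's own chunk objects): the first chunk typically carries the call id, type, and function name, and later chunks carry only argument fragments.

from types import SimpleNamespace as NS

# Illustrative streamed deltas: id/name/type arrive once, arguments arrive in pieces.
deltas = [
    NS(tool_calls=[NS(id="call_abc", type="function",
                      function=NS(name="get_current_weather", arguments=""))]),
    NS(tool_calls=[NS(id=None, type=None,
                      function=NS(name=None, arguments='{"location": "Bos'))]),
    NS(tool_calls=[NS(id=None, type=None,
                      function=NS(name=None, arguments='ton, MA"}'))]),
]

argument_list = []
id = name = type = None
for delta in deltas:
    call = delta.tool_calls[0]
    if call.id:
        id = call.id
    if call.type:
        type = call.type
    if call.function.name:
        name = call.function.name
    if call.function.arguments:
        argument_list.append(call.function.arguments)

combined_arguments = "".join(argument_list)
print(id, name, type, combined_arguments)
# call_abc get_current_weather function {"location": "Boston, MA"}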
@@ -2144,16 +2173,20 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
                 continue # openai v1.0.0 sets content = None for chunks
             content_list.append(content)

-    # Combine the "content" strings into a single string
-    combined_content = "".join(content_list)
+    # Combine the "content" strings into a single string || combine the 'function' strings into a single string
+    combined_content = "".join(combined_arguments)

     # Update the "content" field within the response dictionary
     response["choices"][0]["message"]["content"] = combined_content

+    if len(combined_content) > 0:
+        completion_output = combined_content
+    elif len(combined_arguments) > 0:
+        completion_output = combined_arguments

     # # Update usage information if needed
     if messages:
         response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
-        response["usage"]["completion_tokens"] = token_counter(model=model, text=combined_content)
+        response["usage"]["completion_tokens"] = token_counter(model=model, text=completion_output)
         response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
     return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse())
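One note on the counting path above: completion_output is whichever buffer is non-empty, i.e. the streamed text for a normal reply or the joined arguments for a tool/function call, and that string is what token_counter sees. A small sketch, assuming token_counter is importable from the litellm package (the diff calls it unqualified):

from litellm import token_counter

combined_content = ""                              # nothing streamed in delta.content
combined_arguments = '{"location": "Boston, MA"}'  # joined tool-call arguments

# Mirror the selection added in the diff above.
completion_output = combined_content if len(combined_content) > 0 else combined_arguments
completion_tokens = token_counter(model="gpt-3.5-turbo", text=completion_output)
print(completion_tokens)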
@@ -7,6 +7,7 @@ sys.path.insert(
 from litellm import completion, stream_chunk_builder
 import litellm
 import os, dotenv
+from openai import OpenAI
 import pytest
 dotenv.load_dotenv()
@@ -30,20 +31,81 @@ function_schema = {
     },
 }

-def test_stream_chunk_builder():
+tools_schema = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"]
+                    }
+                },
+                "required": ["location"]
+            }
+        }
+    }
+]
+
+# def test_stream_chunk_builder_tools():
+#     try:
+#         litellm.set_verbose = False
+#         response = client.chat.completions.create(
+#             model="gpt-3.5-turbo",
+#             messages=messages,
+#             tools=tools_schema,
+#             # stream=True,
+#             # complete_response=True # runs stream_chunk_builder under-the-hood
+#         )
+
+#         print(f"response: {response}")
+#         print(f"response usage: {response.usage}")
+#     except Exception as e:
+#         pytest.fail(f"An exception occurred - {str(e)}")
+
+# test_stream_chunk_builder_tools()
+
+def test_stream_chunk_builder_litellm_function_call():
     try:
         litellm.set_verbose = False
-        response = completion(
+        response = litellm.completion(
             model="gpt-3.5-turbo",
             messages=messages,
             functions=[function_schema],
-            stream=True,
-            complete_response=True # runs stream_chunk_builder under-the-hood
+            # stream=True,
+            # complete_response=True # runs stream_chunk_builder under-the-hood
         )

         print(f"response: {response}")
-        print(f"response usage: {response['usage']}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")

-test_stream_chunk_builder()
+# test_stream_chunk_builder_litellm_function_call()
+
+def test_stream_chunk_builder_litellm_tool_call():
+    try:
+        litellm.set_verbose = False
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=messages,
+            tools=tools_schema,
+            stream=True,
+            complete_response = True
+        )
+
+        print(f"complete response: {response}")
+        print(f"complete response usage: {response.usage}")
+
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+test_stream_chunk_builder_litellm_tool_call()
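For reference, a sketch of the message shape stream_chunk_builder now returns inside choices[0] when the stream carried a tool call (the id and argument values below are illustrative); the arguments field is a JSON string, so a caller typically parses it with json.loads:

import json

# Illustrative rebuilt message (the shape assembled in the diff above).
rebuilt_message = {
    "content": None,
    "tool_calls": [{
        "id": "call_abc",
        "type": "function",
        "function": {"name": "get_current_weather",
                     "arguments": '{"location": "Boston, MA"}'},
    }],
}

args = json.loads(rebuilt_message["tool_calls"][0]["function"]["arguments"])
print(args["location"])  # Boston, MA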
@@ -140,12 +140,16 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = None
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(
                     ChatCompletionMessageToolCall(**tool_call)
                 )
+        else:
+            self.tool_calls = None
         if logprobs is not None:
             self._logprobs = logprobs
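The two else branches added above make Message predictable: function_call and tool_calls always exist as attributes, defaulting to None instead of being absent when the provider response omits them. A minimal standalone illustration of the same pattern (the class below is illustrative, not litellm's):

class ExampleMessage:
    """Illustrative only: mirrors the defaulting pattern added to Message above."""

    def __init__(self, function_call=None, tool_calls=None):
        # Populate when provided, otherwise default to None so callers can
        # check the attribute directly instead of guarding with hasattr().
        self.function_call = dict(function_call) if function_call is not None else None
        self.tool_calls = list(tool_calls) if tool_calls is not None else None

msg = ExampleMessage()
assert msg.function_call is None and msg.tool_calls is None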