feat: support parallel function calling (tool_calls)

This commit is contained in:
ishaan-jaff 2023-11-17 15:51:00 -08:00
parent 32f22adf8b
commit 88200432b0

View file

@@ -122,14 +122,30 @@ class FunctionCall(OpenAIObject):
arguments: str arguments: str
name: str name: str
class Function(OpenAIObject):
    """Function payload carried inside a tool call: the target name and its raw arguments."""
    arguments: str  # serialized call arguments (presumably a JSON string — confirm against the API response)
    name: str  # name of the function being invoked
class ChatCompletionMessageToolCall(OpenAIObject):
    """One entry of an assistant message's ``tool_calls`` list."""
    id: str  # provider-assigned identifier for this tool call
    function: Function  # the function name/arguments pair for this call
    type: str  # call type (presumably "function" — verify against the API spec)
class Message(OpenAIObject): class Message(OpenAIObject):
def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, **params): def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, tool_calls=None, **params):
super(Message, self).__init__(**params) super(Message, self).__init__(**params)
self.content = content self.content = content
self.role = role self.role = role
self._logprobs = logprobs if function_call is not None:
if function_call:
self.function_call = FunctionCall(**function_call) self.function_call = FunctionCall(**function_call)
if tool_calls is not None:
self.tool_calls = []
for tool_call in tool_calls:
self.tool_calls.append(
ChatCompletionMessageToolCall(**tool_call)
)
if logprobs is not None:
self._logprobs = logprobs
def get(self, key, default=None): def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -519,7 +535,7 @@ class Logging:
curl_command += "curl -X POST \\\n" curl_command += "curl -X POST \\\n"
curl_command += f"{api_base} \\\n" curl_command += f"{api_base} \\\n"
curl_command += f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" curl_command += f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else ""
curl_command += f"-d '{json.dumps(data)}'\n" curl_command += f"-d '{str(data)}'\n"
if api_base == "": if api_base == "":
curl_command = self.model_call_details curl_command = self.model_call_details
@@ -3060,7 +3076,12 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
raise Exception("Error in response object format") raise Exception("Error in response object format")
choice_list=[] choice_list=[]
for idx, choice in enumerate(response_object["choices"]): for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["message"].get("content", None), role=choice["message"]["role"], function_call=choice["message"].get("function_call", None)) message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"],
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None)
)
finish_reason = choice.get("finish_reason", None) finish_reason = choice.get("finish_reason", None)
if finish_reason == None: if finish_reason == None:
# gpt-4 vision can return 'finish_reason' or 'finish_details' # gpt-4 vision can return 'finish_reason' or 'finish_details'
@@ -3075,6 +3096,9 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
if "id" in response_object: if "id" in response_object:
model_response_object.id = response_object["id"] model_response_object.id = response_object["id"]
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object["system_fingerprint"]
if "model" in response_object: if "model" in response_object:
model_response_object.model = response_object["model"] model_response_object.model = response_object["model"]
return model_response_object return model_response_object