Merge branch 'BerriAI:main' into main

commit f98619e6f2
Vince Loewe, 2024-02-28 22:18:14 -08:00, committed by GitHub
68 changed files with 2676 additions and 1126 deletions


@@ -205,18 +205,18 @@ def map_finish_reason(
 class FunctionCall(OpenAIObject):
     arguments: str
-    name: str
+    name: Optional[str] = None
 class Function(OpenAIObject):
     arguments: str
-    name: str
+    name: Optional[str] = None
 class ChatCompletionDeltaToolCall(OpenAIObject):
-    id: str
+    id: Optional[str] = None
     function: Function
-    type: str
+    type: Optional[str] = None
     index: int
@@ -275,13 +275,19 @@ class Delta(OpenAIObject):
         super(Delta, self).__init__(**params)
         self.content = content
         self.role = role
-        self.function_call = function_call
-        if tool_calls is not None and isinstance(tool_calls, dict):
+        if function_call is not None and isinstance(function_call, dict):
+            self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call
+        if tool_calls is not None and isinstance(tool_calls, list):
             self.tool_calls = []
             for tool_call in tool_calls:
-                if tool_call.get("index", None) is None:
-                    tool_call["index"] = 0
-                self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+                if isinstance(tool_call, dict):
+                    if tool_call.get("index", None) is None:
+                        tool_call["index"] = 0
+                    self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+                elif isinstance(tool_call, ChatCompletionDeltaToolCall):
+                    self.tool_calls.append(tool_call)
         else:
             self.tool_calls = tool_calls
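
The reworked Delta.__init__ above now accepts function_call and tool_calls either as plain dicts (for example, a streamed chunk re-hydrated from JSON) or as already-typed objects. A minimal sketch of both input shapes, assuming Delta, ChatCompletionDeltaToolCall, and Function can be imported from litellm.utils, the module this hunk appears to come from; the id and function name below are made-up illustration values:

# Sketch only: the two tool_calls shapes the new Delta.__init__ handles.
from litellm.utils import ChatCompletionDeltaToolCall, Delta, Function

# 1) Raw dicts, e.g. parsed from a provider's streaming payload; a missing
#    "index" is defaulted to 0 before ChatCompletionDeltaToolCall is built.
delta_from_dicts = Delta(
    role="assistant",
    tool_calls=[
        {
            "id": "call_abc123",  # hypothetical id
            "type": "function",
            "function": {"name": "get_weather", "arguments": ""},
        }
    ],
)

# 2) Already-typed objects are appended unchanged.
typed_call = ChatCompletionDeltaToolCall(
    id="call_abc123",
    type="function",
    function=Function(name="get_weather", arguments=""),
    index=0,
)
delta_from_objects = Delta(role="assistant", tool_calls=[typed_call])
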
@@ -1636,7 +1642,7 @@ class Logging:
             verbose_logger.debug(
                 "Async success callbacks: Got a complete streaming response"
             )
-            self.model_call_details["complete_streaming_response"] = (
+            self.model_call_details["async_complete_streaming_response"] = (
                 complete_streaming_response
             )
             try:
@@ -1684,28 +1690,31 @@ class Logging:
                     print_verbose("async success_callback: reaches cache for logging!")
                     kwargs = self.model_call_details
                     if self.stream:
-                        if "complete_streaming_response" not in kwargs:
+                        if "async_complete_streaming_response" not in kwargs:
                             print_verbose(
-                                f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
+                                f"async success_callback: reaches cache for logging, there is no async_complete_streaming_response. Kwargs={kwargs}\n\n"
                             )
                             pass
                         else:
                             print_verbose(
-                                "async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
+                                "async success_callback: reaches cache for logging, there is a async_complete_streaming_response. Adding to cache"
                             )
-                            result = kwargs["complete_streaming_response"]
+                            result = kwargs["async_complete_streaming_response"]
                             # only add to cache once we have a complete streaming response
                             litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger):  # custom logger class
                     print_verbose(
-                        f"Async success callbacks: {callback}; self.stream: {self.stream}; complete_streaming_response: {self.model_call_details.get('complete_streaming_response', None)}"
+                        f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
                     )
                     if self.stream == True:
-                        if "complete_streaming_response" in self.model_call_details:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
                             await callback.async_log_success_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
-                                    "complete_streaming_response"
+                                    "async_complete_streaming_response"
                                 ],
                                 start_time=start_time,
                                 end_time=end_time,
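
The CustomLogger branch above is the hook that class-based integrations plug into: once the streamed chunks have been assembled, async_log_success_event receives the full response under the renamed "async_complete_streaming_response" key. A rough sketch of such a subclass, assuming litellm's documented CustomLogger interface; the class name and print statements are my own:

from litellm.integrations.custom_logger import CustomLogger

class MyAsyncLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # For streamed calls, response_obj is the rebuilt complete response,
        # not an individual chunk.
        print("model:", kwargs.get("model"))
        print("response:", response_obj)

# Typically registered with: litellm.callbacks = [MyAsyncLogger()]
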
@@ -1726,14 +1735,18 @@ class Logging:
                         )
                 if callable(callback):  # custom logger functions
                     print_verbose(
-                        f"Making async function logging call - {self.model_call_details}"
+                        f"Making async function logging call for {callback}, result={result} - {self.model_call_details}"
                     )
                     if self.stream:
-                        if "complete_streaming_response" in self.model_call_details:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
                             await customLogger.async_log_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
-                                    "complete_streaming_response"
+                                    "async_complete_streaming_response"
                                 ],
                                 start_time=start_time,
                                 end_time=end_time,
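
The callable(callback) branch covers plain async functions registered as callbacks; they are handed the same rebuilt streaming response. A small function-style sketch, assuming litellm's documented callback signature of (kwargs, completion_response, start_time, end_time); the body is illustrative:

async def log_completion(kwargs, completion_response, start_time, end_time):
    # start_time/end_time arrive as datetime objects in litellm's callbacks.
    duration = (end_time - start_time).total_seconds()
    print(f"{kwargs.get('model')} finished in {duration:.2f}s")

# Typically registered with: litellm.success_callback = [log_completion]
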
@@ -1754,14 +1767,17 @@ class Logging:
                     if dynamoLogger is None:
                         dynamoLogger = DyanmoDBLogger()
                     if self.stream:
-                        if "complete_streaming_response" in self.model_call_details:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
                             print_verbose(
                                 "DynamoDB Logger: Got Stream Event - Completed Stream Response"
                             )
                             await dynamoLogger._async_log_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
-                                    "complete_streaming_response"
+                                    "async_complete_streaming_response"
                                 ],
                                 start_time=start_time,
                                 end_time=end_time,
@@ -3715,6 +3731,54 @@ def completion_cost(
         raise e
+def supports_function_calling(model: str):
+    """
+    Check if the given model supports function calling and return a boolean value.
+    Parameters:
+    model (str): The model name to be checked.
+    Returns:
+    bool: True if the model supports function calling, False otherwise.
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    if model in litellm.model_cost:
+        model_info = litellm.model_cost[model]
+        if model_info.get("supports_function_calling", False):
+            return True
+        return False
+    else:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}."
+        )
+def supports_parallel_function_calling(model: str):
+    """
+    Check if the given model supports parallel function calling and return True if it does, False otherwise.
+    Parameters:
+    model (str): The model to check for support of parallel function calling.
+    Returns:
+    bool: True if the model supports parallel function calling, False otherwise.
+    Raises:
+    Exception: If the model is not found in the model_cost dictionary.
+    """
+    if model in litellm.model_cost:
+        model_info = litellm.model_cost[model]
+        if model_info.get("supports_parallel_function_calling", False):
+            return True
+        return False
+    else:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}."
+        )
 ####### HELPER FUNCTIONS ################
 def register_model(model_cost: Union[str, dict]):
     """
@@ -4043,6 +4107,7 @@ def get_optional_params(
         and custom_llm_provider != "vertex_ai"
         and custom_llm_provider != "anyscale"
         and custom_llm_provider != "together_ai"
+        and custom_llm_provider != "mistral"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -4713,7 +4778,14 @@ def get_optional_params(
         if max_tokens:
             optional_params["max_tokens"] = max_tokens
     elif custom_llm_provider == "mistral":
-        supported_params = ["temperature", "top_p", "stream", "max_tokens"]
+        supported_params = [
+            "temperature",
+            "top_p",
+            "stream",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]
         _check_valid_arg(supported_params=supported_params)
         if temperature is not None:
             optional_params["temperature"] = temperature
@@ -4723,6 +4795,10 @@ def get_optional_params(
             optional_params["stream"] = stream
         if max_tokens is not None:
             optional_params["max_tokens"] = max_tokens
+        if tools is not None:
+            optional_params["tools"] = tools
+        if tool_choice is not None:
+            optional_params["tool_choice"] = tool_choice
         # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion
         safe_mode = passed_params.pop("safe_mode", None)
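
With "tools" and "tool_choice" now whitelisted for Mistral, a tool-calling request passes straight through get_optional_params to the provider. A hedged sketch of the call shape; the model name and function schema are illustrative, not part of this commit:

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # example tool, not from this diff
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = litellm.completion(
    model="mistral/mistral-large-latest",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
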
@@ -6947,7 +7023,7 @@ def exception_type(
             if "500 An internal error has occurred." in error_str:
                 exception_mapping_worked = True
                 raise APIError(
-                    status_code=original_exception.status_code,
+                    status_code=getattr(original_exception, "status_code", 500),
                     message=f"PalmException - {original_exception.message}",
                     llm_provider="palm",
                     model=model,
@@ -8730,7 +8806,7 @@ class CustomStreamWrapper:
                         or original_chunk.choices[0].delta.tool_calls is not None
                     ):
                         try:
-                            delta = dict(original_chunk.choices[0].delta)
+                            delta = original_chunk.choices[0].delta
                             model_response.system_fingerprint = (
                                 original_chunk.system_fingerprint
                             )
@@ -8765,7 +8841,9 @@ class CustomStreamWrapper:
                                     is None
                                 ):
                                     t.function.arguments = ""
-                            model_response.choices[0].delta = Delta(**delta)
+                            _json_delta = delta.model_dump()
+                            print_verbose(f"_json_delta: {_json_delta}")
+                            model_response.choices[0].delta = Delta(**_json_delta)
                         except Exception as e:
                             traceback.print_exc()
                             model_response.choices[0].delta = Delta()
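
The switch from dict(delta) to delta.model_dump() matters because the streamed delta is a pydantic model whose tool calls are themselves models: dict() only copies the top-level fields and leaves the nested objects as-is, while model_dump() recursively converts everything to plain dicts that Delta(**_json_delta) can rebuild. A standalone pydantic v2 sketch of that difference, with toy models rather than litellm's own:

from pydantic import BaseModel

class ToolFunction(BaseModel):
    name: str
    arguments: str

class ToolCall(BaseModel):
    id: str
    function: ToolFunction

call = ToolCall(id="call_1", function=ToolFunction(name="f", arguments="{}"))

shallow = dict(call)      # nested value stays a ToolFunction instance
deep = call.model_dump()  # nested value becomes {'name': 'f', 'arguments': '{}'}

print(type(shallow["function"]))  # <class '__main__.ToolFunction'>
print(type(deep["function"]))     # <class 'dict'>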