Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Merge branch 'BerriAI:main' into main
Commit f98619e6f2
68 changed files with 2676 additions and 1126 deletions

litellm/utils.py (130 changed lines)
@@ -205,18 +205,18 @@ def map_finish_reason(
 class FunctionCall(OpenAIObject):
     arguments: str
-    name: str
+    name: Optional[str] = None


 class Function(OpenAIObject):
     arguments: str
-    name: str
+    name: Optional[str] = None


 class ChatCompletionDeltaToolCall(OpenAIObject):
-    id: str
+    id: Optional[str] = None
     function: Function
-    type: str
+    type: Optional[str] = None
     index: int

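Making name, id, and type optional on these delta classes matters for streaming tool calls: continuation chunks usually carry only the next fragment of the JSON arguments. A minimal sketch of the kind of partial chunk that now validates (assuming these classes remain importable from litellm.utils as pydantic-style models; the argument fragment is made up):

from litellm.utils import ChatCompletionDeltaToolCall, Function

# A continuation chunk of a streamed tool call: no id, no type, no function
# name, just a slice of the arguments. With the fields above now Optional,
# this constructs cleanly instead of failing validation.
partial = ChatCompletionDeltaToolCall(
    index=0,
    function=Function(arguments='{"location": "Bos'),
)
print(partial.id, partial.type, partial.function.name)  # None None None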
@@ -275,13 +275,19 @@ class Delta(OpenAIObject):
         super(Delta, self).__init__(**params)
         self.content = content
         self.role = role
-        self.function_call = function_call
-        if tool_calls is not None and isinstance(tool_calls, dict):
+        if function_call is not None and isinstance(function_call, dict):
+            self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call
+        if tool_calls is not None and isinstance(tool_calls, list):
             self.tool_calls = []
             for tool_call in tool_calls:
-                if tool_call.get("index", None) is None:
-                    tool_call["index"] = 0
-                self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+                if isinstance(tool_call, dict):
+                    if tool_call.get("index", None) is None:
+                        tool_call["index"] = 0
+                    self.tool_calls.append(ChatCompletionDeltaToolCall(**tool_call))
+                elif isinstance(tool_call, ChatCompletionDeltaToolCall):
+                    self.tool_calls.append(tool_call)
         else:
             self.tool_calls = tool_calls

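The reworked Delta.__init__ now coerces dict inputs: a dict function_call becomes a FunctionCall, each dict entry in tool_calls gets a default index of 0 and is wrapped in ChatCompletionDeltaToolCall, and already-wrapped entries pass through unchanged. A hedged sketch of that behaviour (import path as in this file; the tool payload is illustrative):

from litellm.utils import Delta

d = Delta(
    role="assistant",
    # dict -> FunctionCall
    function_call={"name": "get_weather", "arguments": ""},
    tool_calls=[
        # dict without an "index" key -> index defaults to 0,
        # then it is wrapped in ChatCompletionDeltaToolCall
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": ""},
        },
    ],
)
print(type(d.function_call).__name__)   # FunctionCall
print(type(d.tool_calls[0]).__name__)   # ChatCompletionDeltaToolCall
print(d.tool_calls[0].index)            # 0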
@@ -1636,7 +1642,7 @@ class Logging:
                     verbose_logger.debug(
                         "Async success callbacks: Got a complete streaming response"
                     )
-                    self.model_call_details["complete_streaming_response"] = (
+                    self.model_call_details["async_complete_streaming_response"] = (
                         complete_streaming_response
                     )
                 try:
@@ -1684,28 +1690,31 @@ class Logging:
                     print_verbose("async success_callback: reaches cache for logging!")
                     kwargs = self.model_call_details
                     if self.stream:
-                        if "complete_streaming_response" not in kwargs:
+                        if "async_complete_streaming_response" not in kwargs:
                             print_verbose(
-                                f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
+                                f"async success_callback: reaches cache for logging, there is no async_complete_streaming_response. Kwargs={kwargs}\n\n"
                             )
                             pass
                         else:
                             print_verbose(
-                                "async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
+                                "async success_callback: reaches cache for logging, there is a async_complete_streaming_response. Adding to cache"
                             )
-                            result = kwargs["complete_streaming_response"]
+                            result = kwargs["async_complete_streaming_response"]
                             # only add to cache once we have a complete streaming response
                             litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger):  # custom logger class
                     print_verbose(
-                        f"Async success callbacks: {callback}; self.stream: {self.stream}; complete_streaming_response: {self.model_call_details.get('complete_streaming_response', None)}"
+                        f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
                     )
                     if self.stream == True:
-                        if "complete_streaming_response" in self.model_call_details:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
                             await callback.async_log_success_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
-                                    "complete_streaming_response"
+                                    "async_complete_streaming_response"
                                 ],
                                 start_time=start_time,
                                 end_time=end_time,
@@ -1726,14 +1735,18 @@ class Logging:
                             )
                 if callable(callback):  # custom logger functions
                     print_verbose(
-                        f"Making async function logging call - {self.model_call_details}"
+                        f"Making async function logging call for {callback}, result={result} - {self.model_call_details}"
                     )
                     if self.stream:
-                        if "complete_streaming_response" in self.model_call_details:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
                             await customLogger.async_log_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
-                                    "complete_streaming_response"
+                                    "async_complete_streaming_response"
                                 ],
                                 start_time=start_time,
                                 end_time=end_time,
@@ -1754,14 +1767,17 @@ class Logging:
                 if dynamoLogger is None:
                     dynamoLogger = DyanmoDBLogger()
                 if self.stream:
-                    if "complete_streaming_response" in self.model_call_details:
+                    if (
+                        "async_complete_streaming_response"
+                        in self.model_call_details
+                    ):
                         print_verbose(
                             "DynamoDB Logger: Got Stream Event - Completed Stream Response"
                         )
                         await dynamoLogger._async_log_event(
                             kwargs=self.model_call_details,
                             response_obj=self.model_call_details[
-                                "complete_streaming_response"
+                                "async_complete_streaming_response"
                             ],
                             start_time=start_time,
                             end_time=end_time,
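The hunks above rename the key the async success path uses for the assembled streaming response from "complete_streaming_response" to "async_complete_streaming_response", keeping it separate from the key the synchronous handler writes. A minimal custom-logger sketch that receives the assembled response for a streamed call (CustomLogger and litellm.callbacks are litellm's documented integration points; the handler name is ours):

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # For streamed calls, response_obj is the response the Logging class
        # now stores under "async_complete_streaming_response".
        print("assembled streamed response:", response_obj)

litellm.callbacks = [MyHandler()]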
@@ -3715,6 +3731,54 @@ def completion_cost(
         raise e


+def supports_function_calling(model: str):
+    """
+    Check if the given model supports function calling and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+
+    Returns:
+    bool: True if the model supports function calling, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    if model in litellm.model_cost:
+        model_info = litellm.model_cost[model]
+        if model_info.get("supports_function_calling", False):
+            return True
+        return False
+    else:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}."
+        )
+
+
+def supports_parallel_function_calling(model: str):
+    """
+    Check if the given model supports parallel function calling and return True if it does, False otherwise.
+
+    Parameters:
+    model (str): The model to check for support of parallel function calling.
+
+    Returns:
+    bool: True if the model supports parallel function calling, False otherwise.
+
+    Raises:
+    Exception: If the model is not found in the model_cost dictionary.
+    """
+    if model in litellm.model_cost:
+        model_info = litellm.model_cost[model]
+        if model_info.get("supports_parallel_function_calling", False):
+            return True
+        return False
+    else:
+        raise Exception(
+            f"Model not in model_prices_and_context_window.json. You passed model={model}."
+        )
+
+
 ####### HELPER FUNCTIONS ################
 def register_model(model_cost: Union[str, dict]):
     """
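Both new helpers look the model up in litellm.model_cost (loaded from model_prices_and_context_window.json) and are exposed on the litellm module. Example usage (the model names are illustrative; results depend on the entries in that JSON file):

import litellm

print(litellm.supports_function_calling(model="gpt-3.5-turbo"))                  # True
print(litellm.supports_parallel_function_calling(model="gpt-4-turbo-preview"))   # True

# Unknown models raise instead of silently returning False:
try:
    litellm.supports_function_calling(model="some-unknown-model")
except Exception as e:
    print(e)  # Model not in model_prices_and_context_window.json. ...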
@@ -4043,6 +4107,7 @@ def get_optional_params(
         and custom_llm_provider != "vertex_ai"
         and custom_llm_provider != "anyscale"
         and custom_llm_provider != "together_ai"
+        and custom_llm_provider != "mistral"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -4713,7 +4778,14 @@ def get_optional_params(
         if max_tokens:
             optional_params["max_tokens"] = max_tokens
     elif custom_llm_provider == "mistral":
-        supported_params = ["temperature", "top_p", "stream", "max_tokens"]
+        supported_params = [
+            "temperature",
+            "top_p",
+            "stream",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]
         _check_valid_arg(supported_params=supported_params)
         if temperature is not None:
             optional_params["temperature"] = temperature
@@ -4723,6 +4795,10 @@ def get_optional_params(
             optional_params["stream"] = stream
         if max_tokens is not None:
             optional_params["max_tokens"] = max_tokens
+        if tools is not None:
+            optional_params["tools"] = tools
+        if tool_choice is not None:
+            optional_params["tool_choice"] = tool_choice

         # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion
         safe_mode = passed_params.pop("safe_mode", None)
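Together, the two hunks above let tools and tool_choice flow through get_optional_params for Mistral instead of being rejected by _check_valid_arg. A hedged end-to-end sketch (the model name and tool schema are illustrative, not part of this diff):

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="mistral/mistral-large-latest",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)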
@@ -6947,7 +7023,7 @@ def exception_type(
                 if "500 An internal error has occurred." in error_str:
                     exception_mapping_worked = True
                     raise APIError(
-                        status_code=original_exception.status_code,
+                        status_code=getattr(original_exception, "status_code", 500),
                         message=f"PalmException - {original_exception.message}",
                         llm_provider="palm",
                         model=model,
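The getattr fallback guards against provider exceptions that carry no status_code attribute, which would otherwise turn the mapping itself into an AttributeError. A toy illustration (the exception class here is made up):

class ProviderError(Exception):
    pass  # note: no status_code attribute

e = ProviderError("500 An internal error has occurred.")
# e.status_code                        -> would raise AttributeError
print(getattr(e, "status_code", 500))  # falls back to 500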
@@ -8730,7 +8806,7 @@ class CustomStreamWrapper:
                         or original_chunk.choices[0].delta.tool_calls is not None
                     ):
                         try:
-                            delta = dict(original_chunk.choices[0].delta)
+                            delta = original_chunk.choices[0].delta
                             model_response.system_fingerprint = (
                                 original_chunk.system_fingerprint
                             )
@@ -8765,7 +8841,9 @@ class CustomStreamWrapper:
                                     is None
                                 ):
                                     t.function.arguments = ""
-                            model_response.choices[0].delta = Delta(**delta)
+                            _json_delta = delta.model_dump()
+                            print_verbose(f"_json_delta: {_json_delta}")
+                            model_response.choices[0].delta = Delta(**_json_delta)
                         except Exception as e:
                             traceback.print_exc()
                             model_response.choices[0].delta = Delta()
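The switch from dict(delta) to delta.model_dump() matters because dict() on a pydantic model is shallow: nested models such as the tool-call entries stay model instances, whereas model_dump() converts them recursively into plain dicts that can be splatted back into Delta(**...). A toy pydantic v2 sketch (standalone classes, not litellm's own):

from typing import List, Optional
from pydantic import BaseModel

class Func(BaseModel):
    name: Optional[str] = None
    arguments: str = ""

class ToolCall(BaseModel):
    index: int = 0
    function: Func = Func()

class DeltaModel(BaseModel):
    tool_calls: List[ToolCall] = []

d = DeltaModel(tool_calls=[ToolCall(function=Func(arguments="{}"))])
print(type(dict(d)["tool_calls"][0]))         # ToolCall instance (shallow copy)
print(type(d.model_dump()["tool_calls"][0]))  # dict (recursive conversion)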