forked from phoenix/litellm-mirror
Merge branch 'BerriAI:main' into main
commit c8b8f93184
50 changed files with 792 additions and 382 deletions

litellm/utils.py (497 changes)
@@ -228,6 +228,24 @@ class Function(OpenAIObject):
    arguments: str
    name: Optional[str] = None

    def __init__(
        self,
        arguments: Union[Dict, str],
        name: Optional[str] = None,
        **params,
    ):
        if isinstance(arguments, Dict):
            arguments = json.dumps(arguments)
        else:
            arguments = arguments

        name = name

        # Build a dictionary with the structure your BaseModel expects
        data = {"arguments": arguments, "name": name, **params}

        super(Function, self).__init__(**data)


class ChatCompletionDeltaToolCall(OpenAIObject):
    id: Optional[str] = None
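With this change, `Function` accepts `arguments` either as a pre-serialized JSON string or as a plain dict; dicts are run through `json.dumps` before the pydantic model is initialized. A minimal standalone sketch of the two call paths, using a simplified stand-in class rather than litellm's `OpenAIObject`:

```python
import json
from typing import Dict, Optional, Union

from pydantic import BaseModel


class Function(BaseModel):  # simplified stand-in for the litellm class
    arguments: str
    name: Optional[str] = None

    def __init__(self, arguments: Union[Dict, str], name: Optional[str] = None, **params):
        # Dicts are serialized so `arguments` is always stored as a JSON string.
        if isinstance(arguments, Dict):
            arguments = json.dumps(arguments)
        data = {"arguments": arguments, "name": name, **params}
        super().__init__(**data)


# Both forms now produce the same stored value.
f1 = Function(arguments={"location": "Boston", "unit": "celsius"}, name="get_weather")
f2 = Function(arguments='{"location": "Boston", "unit": "celsius"}', name="get_weather")
assert f1.arguments == f2.arguments
```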
@@ -2392,210 +2410,202 @@ class Rules:

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def function_setup(
    original_function, rules_obj, start_time, *args, **kwargs
):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
    try:
        global callback_list, add_breadcrumb, user_logger_fn, Logging
        function_id = kwargs["id"] if "id" in kwargs else None
        if len(litellm.callbacks) > 0:
            for callback in litellm.callbacks:
                if callback not in litellm.input_callback:
                    litellm.input_callback.append(callback)
                if callback not in litellm.success_callback:
                    litellm.success_callback.append(callback)
                if callback not in litellm.failure_callback:
                    litellm.failure_callback.append(callback)
                if callback not in litellm._async_success_callback:
                    litellm._async_success_callback.append(callback)
                if callback not in litellm._async_failure_callback:
                    litellm._async_failure_callback.append(callback)
            print_verbose(
                f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
            )
        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0
            or len(litellm.failure_callback) > 0
        ) and len(callback_list) == 0:
            callback_list = list(
                set(
                    litellm.input_callback
                    + litellm.success_callback
                    + litellm.failure_callback
                )
            )
            set_callbacks(callback_list=callback_list, function_id=function_id)
        ## ASYNC CALLBACKS
        if len(litellm.input_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.input_callback):
                if inspect.iscoroutinefunction(callback):
                    litellm._async_input_callback.append(callback)
                    removed_async_items.append(index)

            # Pop the async items from input_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.input_callback.pop(index)

        if len(litellm.success_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.success_callback):
                if inspect.iscoroutinefunction(callback):
                    litellm._async_success_callback.append(callback)
                    removed_async_items.append(index)
                elif callback == "dynamodb":
                    # dynamo is an async callback, it's used for the proxy and needs to be async
                    # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
                    litellm._async_success_callback.append(callback)
                    removed_async_items.append(index)

            # Pop the async items from success_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.success_callback.pop(index)

        if len(litellm.failure_callback) > 0:
            removed_async_items = []
            for index, callback in enumerate(litellm.failure_callback):
                if inspect.iscoroutinefunction(callback):
                    litellm._async_failure_callback.append(callback)
                    removed_async_items.append(index)

            # Pop the async items from failure_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                litellm.failure_callback.pop(index)
        ### DYNAMIC CALLBACKS ###
        dynamic_success_callbacks = None
        dynamic_async_success_callbacks = None
        if kwargs.get("success_callback", None) is not None and isinstance(
            kwargs["success_callback"], list
        ):
            removed_async_items = []
            for index, callback in enumerate(kwargs["success_callback"]):
                if (
                    inspect.iscoroutinefunction(callback)
                    or callback == "dynamodb"
                    or callback == "s3"
                ):
                    if dynamic_async_success_callbacks is not None and isinstance(
                        dynamic_async_success_callbacks, list
                    ):
                        dynamic_async_success_callbacks.append(callback)
                    else:
                        dynamic_async_success_callbacks = [callback]
                    removed_async_items.append(index)
            # Pop the async items from success_callback in reverse order to avoid index issues
            for index in reversed(removed_async_items):
                kwargs["success_callback"].pop(index)
            dynamic_success_callbacks = kwargs.pop("success_callback")

        if add_breadcrumb:
            add_breadcrumb(
                category="litellm.llm_call",
                message=f"Positional Args: {args}, Keyword Args: {kwargs}",
                level="info",
            )
        if "logger_fn" in kwargs:
            user_logger_fn = kwargs["logger_fn"]
        # INIT LOGGER - for user-specified integrations
        model = args[0] if len(args) > 0 else kwargs.get("model", None)
        call_type = original_function.__name__
        if (
            call_type == CallTypes.completion.value
            or call_type == CallTypes.acompletion.value
        ):
            messages = None
            if len(args) > 1:
                messages = args[1]
            elif kwargs.get("messages", None):
                messages = kwargs["messages"]
            ### PRE-CALL RULES ###
            if (
                isinstance(messages, list)
                and len(messages) > 0
                and isinstance(messages[0], dict)
                and "content" in messages[0]
            ):
                rules_obj.pre_call_rules(
                    input="".join(
                        m.get("content", "")
                        for m in messages
                        if "content" in m and isinstance(m["content"], str)
                    ),
                    model=model,
                )
        elif (
            call_type == CallTypes.embedding.value
            or call_type == CallTypes.aembedding.value
        ):
            messages = args[1] if len(args) > 1 else kwargs["input"]
        elif (
            call_type == CallTypes.image_generation.value
            or call_type == CallTypes.aimage_generation.value
        ):
            messages = args[0] if len(args) > 0 else kwargs["prompt"]
        elif (
            call_type == CallTypes.moderation.value
            or call_type == CallTypes.amoderation.value
        ):
            messages = args[1] if len(args) > 1 else kwargs["input"]
        elif (
            call_type == CallTypes.atext_completion.value
            or call_type == CallTypes.text_completion.value
        ):
            messages = args[0] if len(args) > 0 else kwargs["prompt"]
        elif (
            call_type == CallTypes.atranscription.value
            or call_type == CallTypes.transcription.value
        ):
            _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
            messages = "audio_file"
        stream = True if "stream" in kwargs and kwargs["stream"] == True else False
        logging_obj = Logging(
            model=model,
            messages=messages,
            stream=stream,
            litellm_call_id=kwargs["litellm_call_id"],
            function_id=function_id,
            call_type=call_type,
            start_time=start_time,
            dynamic_success_callbacks=dynamic_success_callbacks,
            dynamic_async_success_callbacks=dynamic_async_success_callbacks,
            langfuse_public_key=kwargs.pop("langfuse_public_key", None),
            langfuse_secret=kwargs.pop("langfuse_secret", None),
        )
        ## check if metadata is passed in
        litellm_params = {"api_base": ""}
        if "metadata" in kwargs:
            litellm_params["metadata"] = kwargs["metadata"]
        logging_obj.update_environment_variables(
            model=model,
            user="",
            optional_params={},
            litellm_params=litellm_params,
        )
        return logging_obj, kwargs
    except Exception as e:
        import logging

        logging.debug(
            f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
        )
        raise e

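The setup above partitions registered callbacks into sync and async lists by checking `inspect.iscoroutinefunction` (with string callbacks like `"dynamodb"` routed to the async list), then pops the moved entries in reverse index order so earlier removals do not shift later indices. A small self-contained sketch of that partitioning pattern, with hypothetical callback names rather than litellm's internals:

```python
import inspect


async def async_logger(result):  # hypothetical async callback
    ...


def sync_logger(result):  # hypothetical sync callback
    ...


success_callback = [sync_logger, async_logger, "dynamodb"]
_async_success_callback = []

removed = []
for index, cb in enumerate(success_callback):
    # String callbacks like "dynamodb" are treated as async-only integrations.
    if inspect.iscoroutinefunction(cb) or cb == "dynamodb":
        _async_success_callback.append(cb)
        removed.append(index)

# Pop in reverse order so earlier pops don't shift the remaining indices.
for index in reversed(removed):
    success_callback.pop(index)

assert success_callback == [sync_logger]
assert _async_success_callback == [async_logger, "dynamodb"]
```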
def client(original_function):
    global liteDebuggerClient, get_all_keys
    rules_obj = Rules()

    def function_setup(
        start_time, *args, **kwargs
    ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
        try:
            global callback_list, add_breadcrumb, user_logger_fn, Logging
            function_id = kwargs["id"] if "id" in kwargs else None
            if litellm.use_client or (
                "use_client" in kwargs and kwargs["use_client"] == True
            ):
                if "lite_debugger" not in litellm.input_callback:
                    litellm.input_callback.append("lite_debugger")
                if "lite_debugger" not in litellm.success_callback:
                    litellm.success_callback.append("lite_debugger")
                if "lite_debugger" not in litellm.failure_callback:
                    litellm.failure_callback.append("lite_debugger")
            if len(litellm.callbacks) > 0:
                for callback in litellm.callbacks:
                    if callback not in litellm.input_callback:
                        litellm.input_callback.append(callback)
                    if callback not in litellm.success_callback:
                        litellm.success_callback.append(callback)
                    if callback not in litellm.failure_callback:
                        litellm.failure_callback.append(callback)
                    if callback not in litellm._async_success_callback:
                        litellm._async_success_callback.append(callback)
                    if callback not in litellm._async_failure_callback:
                        litellm._async_failure_callback.append(callback)
                print_verbose(
                    f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
                )
            if (
                len(litellm.input_callback) > 0
                or len(litellm.success_callback) > 0
                or len(litellm.failure_callback) > 0
            ) and len(callback_list) == 0:
                callback_list = list(
                    set(
                        litellm.input_callback
                        + litellm.success_callback
                        + litellm.failure_callback
                    )
                )
                set_callbacks(callback_list=callback_list, function_id=function_id)
            ## ASYNC CALLBACKS
            if len(litellm.input_callback) > 0:
                removed_async_items = []
                for index, callback in enumerate(litellm.input_callback):
                    if inspect.iscoroutinefunction(callback):
                        litellm._async_input_callback.append(callback)
                        removed_async_items.append(index)

                # Pop the async items from input_callback in reverse order to avoid index issues
                for index in reversed(removed_async_items):
                    litellm.input_callback.pop(index)

            if len(litellm.success_callback) > 0:
                removed_async_items = []
                for index, callback in enumerate(litellm.success_callback):
                    if inspect.iscoroutinefunction(callback):
                        litellm._async_success_callback.append(callback)
                        removed_async_items.append(index)
                    elif callback == "dynamodb":
                        # dynamo is an async callback, it's used for the proxy and needs to be async
                        # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
                        litellm._async_success_callback.append(callback)
                        removed_async_items.append(index)

                # Pop the async items from success_callback in reverse order to avoid index issues
                for index in reversed(removed_async_items):
                    litellm.success_callback.pop(index)

            if len(litellm.failure_callback) > 0:
                removed_async_items = []
                for index, callback in enumerate(litellm.failure_callback):
                    if inspect.iscoroutinefunction(callback):
                        litellm._async_failure_callback.append(callback)
                        removed_async_items.append(index)

                # Pop the async items from failure_callback in reverse order to avoid index issues
                for index in reversed(removed_async_items):
                    litellm.failure_callback.pop(index)
            ### DYNAMIC CALLBACKS ###
            dynamic_success_callbacks = None
            dynamic_async_success_callbacks = None
            if kwargs.get("success_callback", None) is not None and isinstance(
                kwargs["success_callback"], list
            ):
                removed_async_items = []
                for index, callback in enumerate(kwargs["success_callback"]):
                    if (
                        inspect.iscoroutinefunction(callback)
                        or callback == "dynamodb"
                        or callback == "s3"
                    ):
                        if dynamic_async_success_callbacks is not None and isinstance(
                            dynamic_async_success_callbacks, list
                        ):
                            dynamic_async_success_callbacks.append(callback)
                        else:
                            dynamic_async_success_callbacks = [callback]
                        removed_async_items.append(index)
                # Pop the async items from success_callback in reverse order to avoid index issues
                for index in reversed(removed_async_items):
                    kwargs["success_callback"].pop(index)
                dynamic_success_callbacks = kwargs.pop("success_callback")

            if add_breadcrumb:
                add_breadcrumb(
                    category="litellm.llm_call",
                    message=f"Positional Args: {args}, Keyword Args: {kwargs}",
                    level="info",
                )
            if "logger_fn" in kwargs:
                user_logger_fn = kwargs["logger_fn"]
            # INIT LOGGER - for user-specified integrations
            model = args[0] if len(args) > 0 else kwargs.get("model", None)
            call_type = original_function.__name__
            if (
                call_type == CallTypes.completion.value
                or call_type == CallTypes.acompletion.value
            ):
                messages = None
                if len(args) > 1:
                    messages = args[1]
                elif kwargs.get("messages", None):
                    messages = kwargs["messages"]
                ### PRE-CALL RULES ###
                if (
                    isinstance(messages, list)
                    and len(messages) > 0
                    and isinstance(messages[0], dict)
                    and "content" in messages[0]
                ):
                    rules_obj.pre_call_rules(
                        input="".join(
                            m.get("content", "")
                            for m in messages
                            if isinstance(m["content"], str)
                        ),
                        model=model,
                    )
            elif (
                call_type == CallTypes.embedding.value
                or call_type == CallTypes.aembedding.value
            ):
                messages = args[1] if len(args) > 1 else kwargs["input"]
            elif (
                call_type == CallTypes.image_generation.value
                or call_type == CallTypes.aimage_generation.value
            ):
                messages = args[0] if len(args) > 0 else kwargs["prompt"]
            elif (
                call_type == CallTypes.moderation.value
                or call_type == CallTypes.amoderation.value
            ):
                messages = args[1] if len(args) > 1 else kwargs["input"]
            elif (
                call_type == CallTypes.atext_completion.value
                or call_type == CallTypes.text_completion.value
            ):
                messages = args[0] if len(args) > 0 else kwargs["prompt"]
            elif (
                call_type == CallTypes.atranscription.value
                or call_type == CallTypes.transcription.value
            ):
                _file_name: BinaryIO = args[1] if len(args) > 1 else kwargs["file"]
                messages = "audio_file"
            stream = True if "stream" in kwargs and kwargs["stream"] == True else False
            logging_obj = Logging(
                model=model,
                messages=messages,
                stream=stream,
                litellm_call_id=kwargs["litellm_call_id"],
                function_id=function_id,
                call_type=call_type,
                start_time=start_time,
                dynamic_success_callbacks=dynamic_success_callbacks,
                dynamic_async_success_callbacks=dynamic_async_success_callbacks,
                langfuse_public_key=kwargs.pop("langfuse_public_key", None),
                langfuse_secret=kwargs.pop("langfuse_secret", None),
            )
            ## check if metadata is passed in
            litellm_params = {"api_base": ""}
            if "metadata" in kwargs:
                litellm_params["metadata"] = kwargs["metadata"]
            logging_obj.update_environment_variables(
                model=model,
                user="",
                optional_params={},
                litellm_params=litellm_params,
            )
            return logging_obj, kwargs
        except Exception as e:
            import logging

            logging.debug(
                f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
            )
            raise e

    def check_coroutine(value) -> bool:
        if inspect.iscoroutine(value):
            return True

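The hunk above moves `function_setup` out of `client` and parameterizes it with `original_function` and `rules_obj`, so the sync and async wrappers inside `client` can share one setup path instead of each closing over the decorator's locals. A minimal sketch of that decorator shape; `shared_setup` and the toy `completion` function are illustrative, not litellm's actual wrapper:

```python
import functools
import time


def shared_setup(original_function, start_time, *args, **kwargs):
    # Stand-in for function_setup: derive the call type from the wrapped
    # function's name and return whatever state the wrapper needs.
    call_type = original_function.__name__
    return {"call_type": call_type, "start_time": start_time}


def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        # Setup can be shared between sync and async wrappers because it now
        # receives the wrapped function explicitly instead of closing over it.
        state = shared_setup(original_function, time.time(), *args, **kwargs)
        result = original_function(*args, **kwargs)
        return state, result

    return wrapper


@client
def completion(model, messages):
    return f"called {model} with {len(messages)} messages"
```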
@@ -2688,7 +2698,9 @@ def client(original_function):

        try:
            if logging_obj is None:
                logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
                logging_obj, kwargs = function_setup(
                    original_function, rules_obj, start_time, *args, **kwargs
                )
            kwargs["litellm_logging_obj"] = logging_obj

            # CHECK FOR 'os.environ/' in kwargs
@@ -2996,7 +3008,9 @@ def client(original_function):

        try:
            if logging_obj is None:
                logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
                logging_obj, kwargs = function_setup(
                    original_function, rules_obj, start_time, *args, **kwargs
                )
            kwargs["litellm_logging_obj"] = logging_obj

            # [OPTIONAL] CHECK BUDGET
@@ -4907,37 +4921,11 @@ def get_optional_params(
        )
        _check_valid_arg(supported_params=supported_params)

        if temperature is not None:
            optional_params["temperature"] = temperature
        if top_p is not None:
            optional_params["top_p"] = top_p
        if stream:
            optional_params["stream"] = stream
        if n is not None:
            optional_params["candidate_count"] = n
        if stop is not None:
            if isinstance(stop, str):
                optional_params["stop_sequences"] = [stop]
            elif isinstance(stop, list):
                optional_params["stop_sequences"] = stop
        if max_tokens is not None:
            optional_params["max_output_tokens"] = max_tokens
        if response_format is not None and response_format["type"] == "json_object":
            optional_params["response_mime_type"] = "application/json"
        if tools is not None and isinstance(tools, list):
            from vertexai.preview import generative_models
        optional_params = litellm.VertexAIConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
        )

            gtool_func_declarations = []
            for tool in tools:
                gtool_func_declaration = generative_models.FunctionDeclaration(
                    name=tool["function"]["name"],
                    description=tool["function"].get("description", ""),
                    parameters=tool["function"].get("parameters", {}),
                )
                gtool_func_declarations.append(gtool_func_declaration)
            optional_params["tools"] = [
                generative_models.Tool(function_declarations=gtool_func_declarations)
            ]
        print_verbose(
            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
        )
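The tool handling in this hunk converts OpenAI-style tool definitions into Vertex AI `FunctionDeclaration`/`Tool` objects. A sketch of that conversion as a standalone helper; it assumes the `google-cloud-aiplatform` SDK is installed (for `vertexai.preview.generative_models`) and uses a made-up tool definition:

```python
from vertexai.preview import generative_models  # requires google-cloud-aiplatform


def openai_tools_to_vertex(tools: list) -> list:
    """Map OpenAI-format tool dicts onto Vertex AI Tool objects."""
    declarations = []
    for tool in tools:
        declarations.append(
            generative_models.FunctionDeclaration(
                name=tool["function"]["name"],
                description=tool["function"].get("description", ""),
                parameters=tool["function"].get("parameters", {}),
            )
        )
    return [generative_models.Tool(function_declarations=declarations)]


# Hypothetical OpenAI-style tool definition
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
vertex_tools = openai_tools_to_vertex(tools)
```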
@@ -5639,17 +5627,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
    elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
        return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
    elif custom_llm_provider == "vertex_ai":
        return [
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "tools",
            "tool_choice",
            "response_format",
            "n",
            "stop",
        ]
        return litellm.VertexAIConfig().get_supported_openai_params()
    elif custom_llm_provider == "sagemaker":
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
    elif custom_llm_provider == "aleph_alpha":
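After this hunk the vertex_ai branch defers to `litellm.VertexAIConfig().get_supported_openai_params()` instead of a hardcoded list, so the provider config is the single source of truth. A hedged usage sketch; the exact parameter list returned depends on the installed litellm version:

```python
import litellm
from litellm.utils import get_supported_openai_params

# With this change, both calls should agree, since one delegates to the other.
via_helper = get_supported_openai_params(model="gemini-pro", custom_llm_provider="vertex_ai")
via_config = litellm.VertexAIConfig().get_supported_openai_params()
assert set(via_helper) == set(via_config)
```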
@@ -6163,7 +6141,13 @@ def get_model_info(model: str):
                "mode": "chat",
            }
        else:
            raise Exception()
            """
            Check if model in model cost map
            """
            if model in litellm.model_cost:
                return litellm.model_cost[model]
            else:
                raise Exception()
    except:
        raise Exception(
            "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
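The new branch replaces the bare `raise Exception()` with a direct lookup in `litellm.model_cost`, so only models missing from the cost map hit the "not mapped yet" error. A short usage sketch; the model names are examples and availability depends on the bundled model_prices_and_context_window.json:

```python
from litellm.utils import get_model_info

info = get_model_info("gpt-3.5-turbo")
print(info.get("max_tokens"), info.get("mode"))

try:
    get_model_info("totally-unknown-model")  # hypothetical unmapped model
except Exception as e:
    # Unmapped models still raise, pointing at model_prices_and_context_window.json
    print(e)
```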
@@ -7868,6 +7852,15 @@ def exception_type(
                        llm_provider="vertex_ai",
                        response=original_exception.response,
                    )
                elif "None Unknown Error." in error_str:
                    exception_mapping_worked = True
                    raise APIError(
                        message=f"VertexAIException - {error_str}",
                        status_code=500,
                        model=model,
                        llm_provider="vertex_ai",
                        request=original_exception.request,
                    )
                elif "403" in error_str:
                    exception_mapping_worked = True
                    raise BadRequestError(
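This adds one more string-match branch to the Vertex AI exception mapping: an opaque "None Unknown Error." from the SDK is surfaced as a 500-style `APIError` instead of falling through unmapped. A simplified sketch of the pattern, using a local stand-in class because litellm's real `APIError` carries more context than shown here:

```python
class APIError(Exception):  # simplified stand-in for litellm's APIError
    def __init__(self, message, status_code, model, llm_provider, request):
        super().__init__(message)
        self.status_code = status_code
        self.model = model
        self.llm_provider = llm_provider
        self.request = request


def map_vertex_error(original_exception, model: str):
    error_str = str(original_exception)
    if "None Unknown Error." in error_str:
        # Opaque SDK failures are surfaced as a 500-style APIError.
        raise APIError(
            message=f"VertexAIException - {error_str}",
            status_code=500,
            model=model,
            llm_provider="vertex_ai",
            request=getattr(original_exception, "request", None),
        )
    raise original_exception
```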
@@ -8882,7 +8875,16 @@ class CustomStreamWrapper:
            raise e

    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        """
        Output parse <s> / </s> special tokens for sagemaker + hf streaming.
        """
        hold = False
        if (
            self.custom_llm_provider != "huggingface"
            and self.custom_llm_provider != "sagemaker"
        ):
            return hold, chunk

        if finish_reason:
            for token in self.special_tokens:
                if token in chunk:
@@ -8898,6 +8900,7 @@ class CustomStreamWrapper:
            for token in self.special_tokens:
                if len(curr_chunk) < len(token) and curr_chunk in token:
                    hold = True
                    self.holding_chunk = curr_chunk
                elif len(curr_chunk) >= len(token):
                    if token in curr_chunk:
                        self.holding_chunk = curr_chunk.replace(token, "")
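`check_special_tokens` holds a chunk back when it could still be the prefix of a special token (e.g. "</s>") that has not finished streaming, and strips the token once the accumulated text is long enough to compare; the added line above makes the held text persist in `self.holding_chunk`. A standalone sketch of that hold-or-strip decision, using a minimal token list rather than the class's `self.special_tokens`:

```python
from typing import Tuple

SPECIAL_TOKENS = ["<s>", "</s>", "<|endoftext|>"]


def hold_or_strip(holding_chunk: str, chunk: str) -> Tuple[bool, str, str]:
    """Return (hold, new_holding_chunk, emitted_chunk)."""
    curr_chunk = holding_chunk + chunk
    for token in SPECIAL_TOKENS:
        if len(curr_chunk) < len(token) and curr_chunk in token:
            # Might still grow into a special token -- hold it back.
            return True, curr_chunk, ""
        if len(curr_chunk) >= len(token) and token in curr_chunk:
            # Full special token seen -- strip it before emitting.
            curr_chunk = curr_chunk.replace(token, "")
    return False, "", curr_chunk


print(hold_or_strip("", "</"))         # (True, '</', '')  -- possible prefix of '</s>'
print(hold_or_strip("</", "s> done"))  # (False, '', ' done') -- token stripped
```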
@@ -9979,6 +9982,7 @@ class CustomStreamWrapper:
                f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
            )
            print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")

            ## RETURN ARG
            if (
                "content" in completion_obj
@@ -10051,7 +10055,6 @@ class CustomStreamWrapper:
            elif self.received_finish_reason is not None:
                if self.sent_last_chunk == True:
                    raise StopIteration

                # flush any remaining holding chunk
                if len(self.holding_chunk) > 0:
                    if model_response.choices[0].delta.content is None:
@@ -10627,7 +10630,9 @@ def trim_messages(
    if max_tokens is None:
        # Check if model is valid
        if model in litellm.model_cost:
            max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
            max_tokens_for_model = litellm.model_cost[model].get(
                "max_input_tokens", litellm.model_cost[model]["max_tokens"]
            )
            max_tokens = int(max_tokens_for_model * trim_ratio)
        else:
            # if user did not specify max (input) tokens
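The reformatted lookup in trim_messages prefers a model's `max_input_tokens`, falls back to `max_tokens`, and then applies `trim_ratio` to get the token budget. A short sketch of that arithmetic with an illustrative cost-map entry (the numbers are made up):

```python
model_cost = {
    "example-model": {"max_tokens": 4096, "max_input_tokens": 8192},  # illustrative entry
}

model = "example-model"
trim_ratio = 0.75

if model in model_cost:
    # Prefer the input-token limit; fall back to the combined max_tokens field.
    max_tokens_for_model = model_cost[model].get(
        "max_input_tokens", model_cost[model]["max_tokens"]
    )
    max_tokens = int(max_tokens_for_model * trim_ratio)

print(max_tokens)  # 6144
```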