diff --git a/litellm/utils.py b/litellm/utils.py
index 4c1cb2efd..3a48958fc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -27,7 +27,6 @@ from dataclasses import (
     dataclass,
     field,
 ) # for storing API inputs, outputs, and metadata
-
 encoding = tiktoken.get_encoding("cl100k_base")
 import importlib.metadata
 from .integrations.traceloop import TraceloopLogger
@@ -57,17 +56,16 @@ from .exceptions import (
     APIConnectionError,
     APIError,
     BudgetExceededError,
-    UnprocessableEntityError,
+    UnprocessableEntityError
 )
 from typing import cast, List, Dict, Union, Optional, Literal
 from .caching import Cache
 from concurrent.futures import ThreadPoolExecutor
-
 ####### ENVIRONMENT VARIABLES ####################
 # Adjust to your specific application needs / system capabilities.
-MAX_THREADS = 100
+MAX_THREADS = 100
 
-# Create a ThreadPoolExecutor
+# Create a ThreadPoolExecutor
 executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 dotenv.load_dotenv() # Loading env variables using dotenv
 sentry_sdk_instance = None
@@ -113,7 +111,6 @@ last_fetched_at_keys = None
 # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
 # }
 
-
 class UnsupportedParamsError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -125,81 +122,64 @@ class UnsupportedParamsError(Exception):
         ) # Call the base class constructor with the parameters it needs
 
-def _generate_id(): # private helper function
-    return "chatcmpl-" + str(uuid.uuid4())
+def _generate_id(): # private helper function
+    return 'chatcmpl-' + str(uuid.uuid4())
 
-
-def map_finish_reason(
-    finish_reason: str,
-): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null'
+def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null'
     # anthropic mapping
     if finish_reason == "stop_sequence":
         return "stop"
     # cohere mapping - https://docs.cohere.com/reference/generate
-    elif finish_reason == "COMPLETE":
+    elif finish_reason == "COMPLETE":
         return "stop"
-    elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
+    elif finish_reason == "MAX_TOKENS": # cohere + vertex ai
         return "length"
-    elif finish_reason == "ERROR_TOXIC":
+    elif finish_reason == "ERROR_TOXIC":
        return "content_filter"
-    elif (
-        finish_reason == "ERROR"
-    ): # openai currently doesn't support an 'error' finish reason
+    elif finish_reason == "ERROR": # openai currently doesn't support an 'error' finish reason
         return "stop"
     # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream
     elif finish_reason == "eos_token" or finish_reason == "stop_sequence":
         return "stop"
-    elif (
-        finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"
-    ): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
+    elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
         return "stop"
-    elif finish_reason == "SAFETY": # vertex ai
+    elif finish_reason == "SAFETY": # vertex ai
        return "content_filter"
     return finish_reason
 
-
 class FunctionCall(OpenAIObject):
     arguments: str
     name: str
 
-
 class Function(OpenAIObject):
     arguments: str
     name: str
 
-
 class ChatCompletionMessageToolCall(OpenAIObject):
     id: str
     function: Function
     type: str
 
-
 class Message(OpenAIObject):
-    def __init__(
-        self,
-        content="default",
-        role="assistant",
-        logprobs=None,
-        function_call=None,
-        tool_calls=None,
-        **params,
-    ):
+    def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, tool_calls=None, **params):
         super(Message, self).__init__(**params)
         self.content = content
         self.role = role
-        if function_call is not None:
+        if function_call is not None:
             self.function_call = FunctionCall(**function_call)
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
-                self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
+                self.tool_calls.append(
+                    ChatCompletionMessageToolCall(**tool_call)
+                )
         if logprobs is not None:
-            self._logprobs = logprobs
+            self._logprobs = logprobs
 
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
         return getattr(self, key, default)
-
+
     def __getitem__(self, key):
         # Allow dictionary-style access to attributes
         return getattr(self, key)
@@ -210,7 +190,7 @@ class Message(OpenAIObject):
 
     def json(self, **kwargs):
         try:
-            return self.model_dump() # noqa
+            return self.model_dump() # noqa
         except:
             # if using pydantic v1
             return self.dict()
@@ -221,7 +201,7 @@ class Delta(OpenAIObject):
         super(Delta, self).__init__(**params)
         self.content = content
         self.role = role
-
+
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
         return hasattr(self, key)
@@ -229,7 +209,7 @@ class Delta(OpenAIObject):
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
         return getattr(self, key, default)
-
+
     def __getitem__(self, key):
         # Allow dictionary-style access to attributes
         return getattr(self, key)
@@ -242,15 +222,13 @@ class Delta(OpenAIObject):
 class Choices(OpenAIObject):
     def __init__(self, finish_reason=None, index=0, message=None, **params):
         super(Choices, self).__init__(**params)
-        self.finish_reason = (
-            map_finish_reason(finish_reason) or "stop"
-        ) # set finish_reason for all responses
+        self.finish_reason = map_finish_reason(finish_reason) or "stop" # set finish_reason for all responses
         self.index = index
         if message is None:
             self.message = Message(content=None)
         else:
             self.message = message
-
+
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
         return hasattr(self, key)
@@ -258,7 +236,7 @@ class Choices(OpenAIObject):
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
         return getattr(self, key, default)
-
+
     def __getitem__(self, key):
         # Allow dictionary-style access to attributes
         return getattr(self, key)
@@ -267,11 +245,8 @@ class Choices(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
 
-
 class Usage(OpenAIObject):
-    def __init__(
-        self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params
-    ):
+    def __init__(self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params):
         super(Usage, self).__init__(**params)
         if prompt_tokens:
             self.prompt_tokens = prompt_tokens
@@ -279,15 +254,15 @@ class Usage(OpenAIObject):
             self.completion_tokens = completion_tokens
         if total_tokens:
             self.total_tokens = total_tokens
-
+
     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
         return hasattr(self, key)
-
+
     def get(self, key, default=None):
         #
Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -296,11 +271,8 @@ class Usage(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) - class StreamingChoices(OpenAIObject): - def __init__( - self, finish_reason=None, index=0, delta: Optional[Delta] = None, **params - ): + def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]=None, **params): super(StreamingChoices, self).__init__(**params) if finish_reason: self.finish_reason = finish_reason @@ -311,15 +283,15 @@ class StreamingChoices(OpenAIObject): self.delta = delta else: self.delta = Delta() - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -328,8 +300,7 @@ class StreamingChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) - -class ModelResponse(OpenAIObject): +class ModelResponse(OpenAIObject): id: str """A unique identifier for the completion.""" @@ -357,20 +328,7 @@ class ModelResponse(OpenAIObject): _hidden_params: dict = {} - def __init__( - self, - id=None, - choices=None, - created=None, - model=None, - object=None, - system_fingerprint=None, - usage=None, - stream=False, - response_ms=None, - hidden_params=None, - **params, - ): + def __init__(self, id=None, choices=None, created=None, model=None, object=None, system_fingerprint=None, usage=None, stream=False, response_ms=None, hidden_params=None, **params): if stream: object = "chat.completion.chunk" choices = [StreamingChoices()] @@ -395,25 +353,16 @@ class ModelResponse(OpenAIObject): usage = Usage() if hidden_params: self._hidden_params = hidden_params - super().__init__( - id=id, - choices=choices, - created=created, - model=model, - object=object, - system_fingerprint=system_fingerprint, - usage=usage, - **params, - ) - + super().__init__(id=id, choices=choices, created=created, model=model, object=object, system_fingerprint=system_fingerprint, usage=usage, **params) + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -421,15 +370,14 @@ class ModelResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() - class Embedding(OpenAIObject): embedding: list = [] index: int @@ -438,7 +386,7 @@ class Embedding(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -447,7 +395,6 @@ class 
Embedding(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) - class EmbeddingResponse(OpenAIObject): model: Optional[str] = None """The model used for embedding.""" @@ -461,19 +408,17 @@ class EmbeddingResponse(OpenAIObject): usage: Optional[Usage] = None """Usage statistics for the embedding request.""" - def __init__( - self, model=None, usage=None, stream=False, response_ms=None, data=None - ): + def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None): object = "list" if response_ms: _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if usage: usage = usage else: @@ -485,11 +430,11 @@ class EmbeddingResponse(OpenAIObject): def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -497,15 +442,14 @@ class EmbeddingResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() - class TextChoices(OpenAIObject): def __init__(self, finish_reason=None, index=0, text=None, logprobs=None, **params): super(TextChoices, self).__init__(**params) @@ -522,7 +466,7 @@ class TextChoices(OpenAIObject): self.logprobs = [] else: self.logprobs = logprobs - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -530,7 +474,7 @@ class TextChoices(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -539,7 +483,6 @@ class TextChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) - class TextCompletionResponse(OpenAIObject): """ { @@ -558,18 +501,7 @@ class TextCompletionResponse(OpenAIObject): "usage": response["usage"] } """ - - def __init__( - self, - id=None, - choices=None, - created=None, - model=None, - usage=None, - stream=False, - response_ms=None, - **params, - ): + def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params): super(TextCompletionResponse, self).__init__(**params) if stream: self.object = "text_completion.chunk" @@ -594,10 +526,9 @@ class TextCompletionResponse(OpenAIObject): self.usage = usage else: self.usage = Usage() - self._hidden_params = ( - {} - ) # used in case users want to access the original model response + self._hidden_params = {} # used in case users want to access the original model response + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -605,7 +536,7 @@ class TextCompletionResponse(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return 
getattr(self, key) @@ -617,34 +548,38 @@ class TextCompletionResponse(OpenAIObject): class ImageResponse(OpenAIObject): created: Optional[int] = None - + data: Optional[list] = None + usage: Optional[dict] = None + def __init__(self, created=None, data=None, response_ms=None): if response_ms: _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if created: created = created else: created = None - + super().__init__(data=data, created=created) + self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -652,69 +587,52 @@ class ImageResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() - - + ############################################################ def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass - ####### LOGGING ################### from enum import Enum - class CallTypes(Enum): - embedding = "embedding" - completion = "completion" - acompletion = "acompletion" - aembedding = "aembedding" - image_generation = "image_generation" - aimage_generation = "aimage_generation" - + embedding = 'embedding' + completion = 'completion' + acompletion = 'acompletion' + aembedding = 'aembedding' + image_generation = 'image_generation' + aimage_generation = 'aimage_generation' # Logging function -> log the exact model details + what's being sent | Non-Blocking class Logging: global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, capture_exception, add_breadcrumb, llmonitorLogger - def __init__( - self, - model, - messages, - stream, - call_type, - start_time, - litellm_call_id, - function_id, - ): + def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id): if call_type not in [item.value for item in CallTypes]: allowed_values = ", ".join([item.value for item in CallTypes]) - raise ValueError( - f"Invalid call_type {call_type}. Allowed values: {allowed_values}" - ) + raise ValueError(f"Invalid call_type {call_type}. 
Allowed values: {allowed_values}") self.model = model self.messages = messages self.stream = stream - self.start_time = start_time # log the call start time + self.start_time = start_time # log the call start time self.call_type = call_type self.litellm_call_id = litellm_call_id self.function_id = function_id - self.streaming_chunks = [] # for generating complete stream response + self.streaming_chunks = [] # for generating complete stream response self.model_call_details = {} - - def update_environment_variables( - self, model, user, optional_params, litellm_params, **additional_params - ): + + def update_environment_variables(self, model, user, optional_params, litellm_params, **additional_params): self.optional_params = optional_params self.model = model self.user = user @@ -731,10 +649,10 @@ class Logging: "user": user, "call_type": str(self.call_type), **self.optional_params, - **additional_params, + **additional_params } - def _pre_call(self, input, api_key, model=None, additional_args={}): + def _pre_call(self, input, api_key, model=None, additional_args={}): """ Common helper function across the sync + async pre-call function """ @@ -744,43 +662,31 @@ class Logging: self.model_call_details["additional_args"] = additional_args self.model_call_details["log_event_type"] = "pre_api_call" if ( - model - ): # if model name was changes pre-call, overwrite the initial model call name with the new one - self.model_call_details["model"] = model + model + ): # if model name was changes pre-call, overwrite the initial model call name with the new one + self.model_call_details["model"] = model def pre_call(self, input, api_key, model=None, additional_args={}): # Log the exact input to the LLM API - litellm.error_logs["PRE_CALL"] = locals() + litellm.error_logs['PRE_CALL'] = locals() try: - self._pre_call( - input=input, - api_key=api_key, - model=model, - additional_args=additional_args, - ) + self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args) # User Logging -> if you pass in a custom logging function headers = additional_args.get("headers", {}) - if headers is None: + if headers is None: headers = {} data = additional_args.get("complete_input_dict", {}) api_base = additional_args.get("api_base", "") - masked_headers = { - k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v - for k, v in headers.items() - } - formatted_headers = " ".join( - [f"-H '{k}: {v}'" for k, v in masked_headers.items()] - ) + masked_headers = {k: (v[:-20] + '*' * 20) if (isinstance(v, str) and len(v) > 20) else v for k, v in headers.items()} + formatted_headers = " ".join([f"-H '{k}: {v}'" for k, v in masked_headers.items()]) print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command += "curl -X POST \\\n" curl_command += f"{api_base} \\\n" - curl_command += ( - f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" - ) + curl_command += f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" curl_command += f"-d '{str(data)}'\n" if additional_args.get("request_str", None) is not None: # print the sagemaker / bedrock client request @@ -801,17 +707,10 @@ class Logging: if litellm.max_budget and self.stream: start_time = self.start_time - end_time = ( - self.start_time - ) # no time has passed as the call hasn't been made yet + end_time = self.start_time # no time has passed as the call hasn't been made yet time_diff = (end_time - start_time).total_seconds() 
float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost( - model=self.model, - prompt="".join(message["content"] for message in self.messages), - completion="", - total_time=float_diff, - ) + litellm._current_cost += litellm.completion_cost(model=self.model, prompt="".join(message["content"] for message in self.messages), completion="", total_time=float_diff) # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: @@ -830,9 +729,7 @@ class Logging: ) elif callback == "lite_debugger": - print_verbose( - f"reaches litedebugger for logging! - model_call_details {self.model_call_details}" - ) + print_verbose(f"reaches litedebugger for logging! - model_call_details {self.model_call_details}") model = self.model_call_details["model"] messages = self.model_call_details["input"] print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") @@ -844,7 +741,7 @@ class Logging: litellm_params=self.model_call_details["litellm_params"], optional_params=self.model_call_details["optional_params"], print_verbose=print_verbose, - call_type=self.call_type, + call_type=self.call_type ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -853,19 +750,19 @@ class Logging: message=f"Model Call Details pre-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_pre_api_call( model=self.model, messages=self.messages, kwargs=self.model_call_details, ) - elif callable(callback): # custom logger functions + elif callable(callback): # custom logger functions customLogger.log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, print_verbose=print_verbose, - callback_func=callback, + callback_func=callback ) except Exception as e: traceback.print_exc() @@ -887,48 +784,37 @@ class Logging: if capture_exception: # log this error to sentry for debugging capture_exception(e) - async def async_pre_call( - self, result=None, start_time=None, end_time=None, **kwargs - ): + async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs): """ - Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, end_time=end_time, result=result - ) + start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) print_verbose(f"Async input callbacks: {litellm._async_input_callback}") for callback in litellm._async_input_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async input callbacks: CustomLogger") - asyncio.create_task( - callback.async_log_input_event( + asyncio.create_task(callback.async_log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, - ) - ) - if callable(callback): # custom logger functions + )) + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") - asyncio.create_task( - customLogger.async_log_input_event( - model=self.model, - messages=self.messages, - kwargs=self.model_call_details, - print_verbose=print_verbose, - callback_func=callback, - ) - ) - except: + asyncio.create_task(customLogger.async_log_input_event( + model=self.model, + messages=self.messages, + kwargs=self.model_call_details, + print_verbose=print_verbose, + callback_func=callback + )) + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) - - def post_call( - self, original_response, input=None, api_key=None, additional_args={} - ): + def post_call(self, original_response, input=None, api_key=None, additional_args={}): # Log the exact result from the LLM API, for streaming - log the type of response received - litellm.error_logs["POST_CALL"] = locals() + litellm.error_logs['POST_CALL'] = locals() try: self.model_call_details["input"] = input self.model_call_details["api_key"] = api_key @@ -937,15 +823,11 @@ class Logging: self.model_call_details["log_event_type"] = "post_api_call" # User Logging -> if you pass in a custom logging function - print_verbose( - f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n" - ) + print_verbose(f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n") print_verbose( f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" ) - print_verbose( - f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}" - ) + print_verbose(f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}") if self.logger_fn and callable(self.logger_fn): try: self.logger_fn( @@ -955,7 +837,7 @@ class Logging: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: try: @@ -966,8 +848,8 @@ class Logging: original_response=original_response, litellm_call_id=self.litellm_params["litellm_call_id"], print_verbose=print_verbose, - call_type=self.call_type, - stream=self.stream, + call_type = self.call_type, + stream = self.stream, ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -976,12 +858,12 @@ class Logging: message=f"Model Call Details post-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger 
class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_post_api_call( kwargs=self.model_call_details, response_obj=None, start_time=self.start_time, - end_time=None, + end_time=None ) except Exception as e: print_verbose( @@ -997,11 +879,9 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) pass - - def _success_handler_helper_fn( - self, result=None, start_time=None, end_time=None, cache_hit=None - ): - try: + + def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None): + try: if start_time is None: start_time = self.start_time if end_time is None: @@ -1013,67 +893,42 @@ class Logging: if litellm.max_budget and self.stream: time_diff = (end_time - start_time).total_seconds() float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost( - model=self.model, - prompt="", - completion=result["content"], - total_time=float_diff, - ) + litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff) return start_time, end_time, result - except Exception as e: + except Exception as e: print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - def success_handler( - self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs - ): - print_verbose(f"Logging Details LiteLLM-Success Call") + def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): + print_verbose( + f"Logging Details LiteLLM-Success Call" + ) # print(f"original response in success handler: {self.model_call_details['original_response']}") try: - print_verbose(f"success callbacks: {litellm.success_callback}") + print_verbose(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if ( - self.stream - and self.model_call_details.get("litellm_params", {}).get( - "acompletion", False - ) - == False - ): # only call stream chunk builder if it's not acompletion() - if ( - result.choices[0].finish_reason is not None - ): # if it's the last chunk + if self.stream and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # only call stream chunk builder if it's not acompletion() + if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder( - self.streaming_chunks, - messages=self.model_call_details.get("messages", None), - ) - except: + try: + complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) + except: complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response + if complete_streaming_response: + self.model_call_details["complete_streaming_response"] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, - end_time=end_time, - result=result, - cache_hit=cache_hit, - ) + start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) for callback in litellm.success_callback: try: if callback == 
"lite_debugger": print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - print_verbose( - f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}" - ) + print_verbose(f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}") liteDebuggerClient.log_event( end_user=kwargs.get("user", "default"), response_obj=result, @@ -1081,8 +936,8 @@ class Logging: end_time=end_time, litellm_call_id=self.litellm_call_id, print_verbose=print_verbose, - call_type=self.call_type, - stream=self.stream, + call_type = self.call_type, + stream = self.stream, ) if callback == "promptlayer": print_verbose("reaches promptlayer for logging!") @@ -1095,8 +950,8 @@ class Logging: ) if callback == "supabase": print_verbose("reaches supabase for logging!") - kwargs = self.model_call_details - + kwargs=self.model_call_details + # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: if "complete_streaming_response" not in kwargs: @@ -1104,7 +959,7 @@ class Logging: else: print_verbose("reaches supabase for streaming logging!") result = kwargs["complete_streaming_response"] - + model = kwargs["model"] messages = kwargs["messages"] optional_params = kwargs.get("optional_params", {}) @@ -1116,9 +971,7 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - litellm_call_id=litellm_params.get( - "litellm_call_id", str(uuid.uuid4()) - ), + litellm_call_id=litellm_params.get("litellm_call_id", str(uuid.uuid4())), print_verbose=print_verbose, ) if callback == "wandb": @@ -1143,16 +996,10 @@ class Logging: print_verbose("reaches llmonitor for logging!") model = self.model - input = self.model_call_details.get( - "messages", self.model_call_details.get("input", None) - ) + input = self.model_call_details.get("messages", self.model_call_details.get("input", None)) # if contains input, it's 'embedding', otherwise 'llm' - type = ( - "embed" - if self.call_type == CallTypes.embedding.value - else "llm" - ) + type = "embed" if self.call_type == CallTypes.embedding.value else "llm" llmonitorLogger.log_event( type=type, @@ -1182,10 +1029,8 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1201,7 +1046,6 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - user_id=self.model_call_details.get("user", "default"), print_verbose=print_verbose, ) if callback == "cache" and litellm.cache is not None: @@ -1210,21 +1054,17 @@ class Logging: kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose( - f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" - ) + print_verbose(f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") return else: - print_verbose( - "success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache" - ) + print_verbose("success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache") result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) if callback == "traceloop": deep_copy = {} - for k, v in self.model_call_details.items(): - if k != "original_response": + for k, v in self.model_call_details.items(): + if k != "original_response": deep_copy[k] = v traceloopLogger.log_event( kwargs=deep_copy, @@ -1233,32 +1073,18 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - elif ( - isinstance(callback, CustomLogger) - and self.model_call_details.get("litellm_params", {}).get( - "acompletion", False - ) - == False - and self.model_call_details.get("litellm_params", {}).get( - "aembedding", False - ) - == False - ): # custom logger class - print_verbose(f"success callbacks: Running Custom Logger Class") + elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class + print_verbose(f"success callbacks: Running Custom Logger Class") if self.stream and complete_streaming_response is None: callback.log_stream_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time, - ) + end_time=end_time + ) else: if self.stream and complete_streaming_response: - self.model_call_details[ - "complete_response" - ] = self.model_call_details.get( - "complete_streaming_response", {} - ) + self.model_call_details["complete_response"] = self.model_call_details.get("complete_streaming_response", {}) result = self.model_call_details["complete_response"] callback.log_success_event( kwargs=self.model_call_details, @@ -1266,17 +1092,15 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions - print_verbose( - f"success callbacks: Running Custom Callback Function" - ) + if callable(callback): # custom logger functions + print_verbose(f"success callbacks: Running Custom Callback Function") customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback, + callback_func=callback ) except Exception as e: @@ -1294,77 +1118,60 @@ class Logging: ) pass - async def async_success_handler( - self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs - ): + async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" print_verbose(f"Async success callbacks: {litellm._async_success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream: - if result.choices[0].finish_reason is not None: # if it's the last chunk + if self.stream: + if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder( - self.streaming_chunks, - messages=self.model_call_details.get("messages", None), - ) + try: + complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) except Exception as e: - print_verbose( - f"Error occurred building stream chunk: {traceback.format_exc()}" - ) + print_verbose(f"Error occurred building stream chunk: {traceback.format_exc()}") complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: + if complete_streaming_response: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit - ) + self.model_call_details["complete_streaming_response"] = complete_streaming_response + start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) for callback in litellm._async_success_callback: - try: + try: if callback == "cache" and litellm.cache is not None: # set_cache once complete streaming response is built print_verbose("async success_callback: reaches cache for logging!") kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose( - f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" - ) + print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") return else: - print_verbose( - "async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache" - ) + print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache") result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) - if isinstance(callback, CustomLogger): # custom logger class + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async success callbacks: CustomLogger") if self.stream: if "complete_streaming_response" in self.model_call_details: await callback.async_log_success_event( kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "complete_streaming_response" - ], + response_obj=self.model_call_details["complete_streaming_response"], start_time=start_time, end_time=end_time, ) - else: - await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function + else: + await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time, - ) + end_time=end_time + ) else: await callback.async_log_success_event( kwargs=self.model_call_details, @@ -1372,7 +1179,7 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") await customLogger.async_log_event( kwargs=self.model_call_details, @@ -1380,7 +1187,7 @@ class Logging: start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback, + callback_func=callback ) if callback == "dynamodb": global dynamoLogger @@ -1388,22 +1195,16 @@ class Logging: dynamoLogger = DyanmoDBLogger() if self.stream: if "complete_streaming_response" in self.model_call_details: - print_verbose( - "DynamoDB Logger: Got Stream Event - Completed Stream Response" - ) + print_verbose("DynamoDB Logger: Got Stream Event - Completed Stream Response") await dynamoLogger._async_log_event( kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "complete_streaming_response" - ], + response_obj=self.model_call_details["complete_streaming_response"], start_time=start_time, end_time=end_time, - print_verbose=print_verbose, - ) - else: - print_verbose( - "DynamoDB Logger: Got Stream Event - No complete stream response as yet" + print_verbose=print_verbose ) + else: + print_verbose("DynamoDB Logger: Got Stream Event - No complete stream response as yet") else: await dynamoLogger._async_log_event( kwargs=self.model_call_details, @@ -1416,10 +1217,8 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if ( - k != "original_response" - ): # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1432,21 +1231,18 @@ class Logging: langFuseLogger = LangFuseLogger() await langFuseLogger._async_log_event( kwargs=kwargs, - user_id=self.model_call_details.get("user", "default"), response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, ) - except: + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) pass - def _failure_handler_helper_fn( - self, exception, 
traceback_exception, start_time=None, end_time=None - ): + def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None): if start_time is None: start_time = self.start_time if end_time is None: @@ -1463,58 +1259,49 @@ class Logging: self.model_call_details.setdefault("original_response", None) return start_time, end_time - def failure_handler( - self, exception, traceback_exception, start_time=None, end_time=None - ): - print_verbose(f"Logging Details LiteLLM-Failure Call") - try: - start_time, end_time = self._failure_handler_helper_fn( - exception=exception, - traceback_exception=traceback_exception, - start_time=start_time, - end_time=end_time, + def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): + print_verbose( + f"Logging Details LiteLLM-Failure Call" ) - result = None # result sent to all loggers, init this to None incase it's not created + try: + start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm.failure_callback: try: if callback == "lite_debugger": - print_verbose("reaches lite_debugger for logging!") - print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - result = { - "model": self.model, - "created": time.time(), - "error": traceback_exception, - "usage": { - "prompt_tokens": prompt_token_calculator( - self.model, messages=self.messages - ), - "completion_tokens": 0, - }, - } - liteDebuggerClient.log_event( - model=self.model, - messages=self.messages, - end_user=self.model_call_details.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=self.litellm_call_id, - print_verbose=print_verbose, - call_type=self.call_type, - stream=self.stream, - ) + print_verbose("reaches lite_debugger for logging!") + print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") + result = { + "model": self.model, + "created": time.time(), + "error": traceback_exception, + "usage": { + "prompt_tokens": prompt_token_calculator( + self.model, messages=self.messages + ), + "completion_tokens": 0, + }, + } + liteDebuggerClient.log_event( + model=self.model, + messages=self.messages, + end_user=self.model_call_details.get("user", "default"), + response_obj=result, + start_time=start_time, + end_time=end_time, + litellm_call_id=self.litellm_call_id, + print_verbose=print_verbose, + call_type = self.call_type, + stream = self.stream, + ) elif callback == "llmonitor": print_verbose("reaches llmonitor for logging error!") model = self.model input = self.model_call_details["input"] - - type = ( - "embed" - if self.call_type == CallTypes.embedding.value - else "llm" - ) + + type = "embed" if self.call_type == CallTypes.embedding.value else "llm" llmonitorLogger.log_event( type=type, @@ -1533,29 +1320,17 @@ class Logging: if capture_exception: capture_exception(exception) else: - print_verbose( - f"capture exception not initialized: {capture_exception}" - ) - elif callable(callback): # custom logger functions + print_verbose(f"capture exception not initialized: {capture_exception}") + elif callable(callback): # custom logger functions customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback, + callback_func=callback ) - elif ( - 
isinstance(callback, CustomLogger) - and self.model_call_details.get("litellm_params", {}).get( - "acompletion", False - ) - == False - and self.model_call_details.get("litellm_params", {}).get( - "aembedding", False - ) - == False - ): # custom logger class + elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class callback.log_failure_event( start_time=start_time, end_time=end_time, @@ -1577,43 +1352,37 @@ class Logging: ) pass - async def async_failure_handler( - self, exception, traceback_exception, start_time=None, end_time=None - ): + async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ - start_time, end_time = self._failure_handler_helper_fn( - exception=exception, - traceback_exception=traceback_exception, - start_time=start_time, - end_time=end_time, - ) - result = None # result sent to all loggers, init this to None incase it's not created + start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm._async_failure_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class await callback.async_log_failure_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions await customLogger.async_log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback, - ) - except Exception as e: + callback_func=callback + ) + except Exception as e: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) + def exception_logging( additional_args={}, logger_fn=None, @@ -1646,59 +1415,53 @@ def exception_logging( ####### RULES ################### - -class Rules: +class Rules: """ Fail calls based on the input or llm api output - Example usage: - import litellm - def my_custom_rule(input): # receives the model response - if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer - return False - return True - + Example usage: + import litellm + def my_custom_rule(input): # receives the model response + if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer + return False + return True + litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call - response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", - "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) + response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", + "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) """ - def __init__(self) -> None: pass - def pre_call_rules(self, input: str, model: str): - for rule in litellm.pre_call_rules: - if callable(rule): + def pre_call_rules(self, input: 
str, model: str): + for rule in litellm.pre_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore - return True + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True - def post_call_rules(self, input: str, model: str): - for rule in litellm.post_call_rules: - if callable(rule): + def post_call_rules(self, input: str, model: str): + for rule in litellm.post_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore - return True - + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True ####### CLIENT ################### # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking def client(original_function): global liteDebuggerClient, get_all_keys rules_obj = Rules() - def function_setup( start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. try: global callback_list, add_breadcrumb, user_logger_fn, Logging function_id = kwargs["id"] if "id" in kwargs else None - if litellm.use_client or ( - "use_client" in kwargs and kwargs["use_client"] == True - ): + if litellm.use_client or ("use_client" in kwargs and kwargs["use_client"] == True): print_verbose(f"litedebugger initialized") if "lite_debugger" not in litellm.input_callback: litellm.input_callback.append("lite_debugger") @@ -1706,8 +1469,8 @@ def client(original_function): litellm.success_callback.append("lite_debugger") if "lite_debugger" not in litellm.failure_callback: litellm.failure_callback.append("lite_debugger") - if len(litellm.callbacks) > 0: - for callback in litellm.callbacks: + if len(litellm.callbacks) > 0: + for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) if callback not in litellm.success_callback: @@ -1718,9 +1481,7 @@ def client(original_function): litellm._async_success_callback.append(callback) if callback not in litellm._async_failure_callback: litellm._async_failure_callback.append(callback) - print_verbose( - f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" - ) + print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}") if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 @@ -1733,7 +1494,10 @@ def client(original_function): + litellm.failure_callback ) ) - set_callbacks(callback_list=callback_list, function_id=function_id) + set_callbacks( + callback_list=callback_list, + function_id=function_id + ) ## ASYNC CALLBACKS if len(litellm.input_callback) > 0: removed_async_items = [] @@ -1746,10 +1510,10 @@ def client(original_function): for index in reversed(removed_async_items): litellm.input_callback.pop(index) - if len(litellm.success_callback) > 0: + if len(litellm.success_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.success_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.success_callback): + if 
inspect.iscoroutinefunction(callback): litellm._async_success_callback.append(callback) removed_async_items.append(index) elif callback == "dynamodb": @@ -1757,9 +1521,7 @@ def client(original_function): # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy litellm._async_success_callback.append(callback) removed_async_items.append(index) - elif callback == "langfuse" and inspect.iscoroutinefunction( - original_function - ): + elif callback == "langfuse" and inspect.iscoroutinefunction(original_function): # use async success callback for langfuse if this is litellm.acompletion(). Streaming logging does not work otherwise litellm._async_success_callback.append(callback) removed_async_items.append(index) @@ -1767,11 +1529,11 @@ def client(original_function): # Pop the async items from success_callback in reverse order to avoid index issues for index in reversed(removed_async_items): litellm.success_callback.pop(index) - - if len(litellm.failure_callback) > 0: + + if len(litellm.failure_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.failure_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.failure_callback): + if inspect.iscoroutinefunction(callback): litellm._async_failure_callback.append(callback) removed_async_items.append(index) @@ -1791,70 +1553,38 @@ def client(original_function): # INIT LOGGER - for user-specified integrations model = args[0] if len(args) > 0 else kwargs.get("model", None) call_type = original_function.__name__ - if ( - call_type == CallTypes.completion.value - or call_type == CallTypes.acompletion.value - ): + if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: messages = None if len(args) > 1: - messages = args[1] + messages = args[1] elif kwargs.get("messages", None): messages = kwargs["messages"] - ### PRE-CALL RULES ### - if ( - isinstance(messages, list) - and len(messages) > 0 - and isinstance(messages[0], dict) - and "content" in messages[0] - ): - rules_obj.pre_call_rules( - input="".join( - m["content"] - for m in messages - if isinstance(m["content"], str) - ), - model=model, - ) - elif ( - call_type == CallTypes.embedding.value - or call_type == CallTypes.aembedding.value - ): + ### PRE-CALL RULES ### + if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) and "content" in messages[0]: + rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model) + elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value: messages = args[1] if len(args) > 1 else kwargs["input"] - elif ( - call_type == CallTypes.image_generation.value - or call_type == CallTypes.aimage_generation.value - ): + elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value: messages = args[0] if len(args) > 0 else kwargs["prompt"] stream = True if "stream" in kwargs and kwargs["stream"] == True else False - logging_obj = Logging( - model=model, - messages=messages, - stream=stream, - litellm_call_id=kwargs["litellm_call_id"], - function_id=function_id, - call_type=call_type, - start_time=start_time, - ) + logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time) return logging_obj - except Exception as e: + except Exception as e: import logging - - 
logging.debug( - f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}" - ) + logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}") raise e - + def post_call_processing(original_response, model): - try: - call_type = original_function.__name__ - if ( - call_type == CallTypes.completion.value - or call_type == CallTypes.acompletion.value - ): - model_response = original_response["choices"][0]["message"]["content"] - ### POST-CALL RULES ### - rules_obj.post_call_rules(input=model_response, model=model) - except Exception as e: + try: + if original_response is None: + pass + else: + call_type = original_function.__name__ + if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: + model_response = original_response['choices'][0]['message']['content'] + ### POST-CALL RULES ### + rules_obj.post_call_rules(input=model_response, model=model) + except Exception as e: raise e def crash_reporting(*args, **kwargs): @@ -1887,7 +1617,8 @@ def client(original_function): try: model = args[0] if len(args) > 0 else kwargs["model"] except: - call_type = original_function.__name__ + model = None + call_type = original_function.__name__ if call_type != CallTypes.image_generation.value: raise ValueError("model param not passed in.") @@ -1896,130 +1627,86 @@ def client(original_function): logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) + raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) # [OPTIONAL] CHECK CACHE - print_verbose( - f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" - ) - # if caching is false, don't run this - if ( - kwargs.get("caching", None) is None and litellm.cache is not None - ) or kwargs.get( - "caching", False - ) == True: # allow users to control returning cached responses from the completion function + print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") + # if caching is false, don't run this + if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): + if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: print_verbose(f"Checking Cache") preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs[ - "preset_cache_key" - ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key + kwargs["preset_cache_key"] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result != None: - if "detail" in cached_result: - # implies an error occurred + if "detail" in cached_result: + # implies an error occurred pass - else: + else: call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == 
CallTypes.completion.value and isinstance( - cached_result, dict - ): - return convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - return convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - else: + print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}") + if call_type == CallTypes.completion.value and isinstance(cached_result, dict): + return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse(), stream = kwargs.get("stream", False)) + elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict): + return convert_to_model_response_object(response_object=cached_result, response_type="embedding") + else: return cached_result # MODEL CALL result = original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: # TODO: Add to cache for streaming - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): + if "complete_response" in kwargs and kwargs["complete_response"] == True: chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: + return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) + else: return result - elif "acompletion" in kwargs and kwargs["acompletion"] == True: + elif "acompletion" in kwargs and kwargs["acompletion"] == True: return result - elif "aembedding" in kwargs and kwargs["aembedding"] == True: + elif "aembedding" in kwargs and kwargs["aembedding"] == True: return result - - ### POST-CALL RULES ### + elif "aimg_generation" in kwargs and kwargs["aimg_generation"] == True: + return result + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model or None) # [OPTIONAL] ADD TO CACHE - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): + if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: litellm.cache.add_cache(result, *args, **kwargs) # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated print_verbose(f"Wrapper: Completed Call, calling success_handler") - threading.Thread( - target=logging_obj.success_handler, args=(result, start_time, end_time) - ).start() + threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() # RETURN RESULT - result._response_ms = ( - end_time - start_time - ).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai return result except Exception as e: call_type = original_function.__name__ if call_type == CallTypes.completion.value: num_retries = ( - kwargs.get("num_retries", None) or litellm.num_retries or None - ) - litellm.num_retries = ( - None # set retries to None to prevent infinite loops - ) - context_window_fallback_dict = kwargs.get( - "context_window_fallback_dict", {} + kwargs.get("num_retries", None) + or litellm.num_retries + or None ) + litellm.num_retries = None # set retries to None 
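# Illustrative sketch of the check-cache-then-call flow above (the real code
# delegates to litellm.cache and convert_to_model_response_object; the dict
# cache and helper names below are assumptions for the example).
cache = {}
calls = {"count": 0}

def make_request():
    calls["count"] += 1
    return {"choices": [{"message": {"content": "hi"}}]}

def cached_call(key, request_fn):
    hit = cache.get(key)
    if hit is not None:
        return hit                # cache hit: skip the model call entirely
    result = request_fn()         # cache miss: do the real call...
    cache[key] = result           # ...and store it for next time
    return result

first = cached_call("prompt-1", make_request)
second = cached_call("prompt-1", make_request)
assert first is second and calls["count"] == 1    # second call served from cache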
to prevent infinite loops + context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) - if num_retries: - if isinstance(e, openai.APIError) or isinstance(e, openai.Timeout): + if num_retries: + if (isinstance(e, openai.APIError) + or isinstance(e, openai.Timeout)): kwargs["num_retries"] = num_retries return litellm.completion_with_retries(*args, **kwargs) - elif ( - isinstance(e, litellm.exceptions.ContextWindowExceededError) - and context_window_fallback_dict - and model in context_window_fallback_dict - ): + elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return original_function(*args, **kwargs) @@ -2028,9 +1715,7 @@ def client(original_function): end_time = datetime.datetime.now() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated if logging_obj: - logging_obj.failure_handler( - e, traceback_exception, start_time, end_time - ) # DO NOT MAKE THREADED - router retry fallback relies on this! + logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! my_thread = threading.Thread( target=handle_failure, args=(e, traceback_exception, start_time, end_time, args, kwargs), @@ -2042,8 +1727,8 @@ def client(original_function): ): # make it easy to get to the debugger logs if you've initialized it e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e - - async def wrapper_async(*args, **kwargs): + + async def wrapper_async(*args, **kwargs): start_time = datetime.datetime.now() result = None logging_obj = kwargs.get("litellm_logging_obj", None) @@ -2054,215 +1739,115 @@ def client(original_function): model = args[0] if len(args) > 0 else kwargs["model"] except: raise ValueError("model param not passed in.") - - try: + + try: if logging_obj is None: logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) + raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) # [OPTIONAL] CHECK CACHE print_verbose(f"litellm.cache: {litellm.cache}") - print_verbose( - f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" - ) - # if caching is false, don't run this - if ( - kwargs.get("caching", None) is None and litellm.cache is not None - ) or kwargs.get( - "caching", False - ) == True: # allow users to control returning cached responses from the completion function + print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") + # if caching is false, don't run this + if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): + if litellm.cache is not None and 
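# Sketch of the context-window fallback above: when a call fails because the
# prompt exceeds the model's context window, retry once with a larger-context
# model taken from a user-supplied mapping. Exception and model names are
# illustrative.
class ContextWindowExceeded(Exception):
    pass

def call_with_fallback(call_fn, model, fallback_map):
    try:
        return call_fn(model)
    except ContextWindowExceeded:
        if model in fallback_map:
            return call_fn(fallback_map[model])   # retry with the mapped model
        raise

def fake_call(model):
    if model == "small-4k":
        raise ContextWindowExceeded("prompt too long")
    return f"answered by {model}"

print(call_with_fallback(fake_call, "small-4k", {"small-4k": "big-16k"}))
# -> answered by big-16k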
str(original_function.__name__) in litellm.cache.supported_call_types: print_verbose(f"Checking Cache") cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result != None: print_verbose(f"Cache Hit!") call_type = original_function.__name__ - if call_type == CallTypes.acompletion.value and isinstance( - cached_result, dict - ): + if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict): if kwargs.get("stream", False) == True: cached_result = convert_to_streaming_response_async( response_object=cached_result, ) else: - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - ) - elif call_type == CallTypes.aembedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=EmbeddingResponse(), - response_type="embedding", - ) - # LOG SUCCESS + cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse()) + elif call_type == CallTypes.aembedding.value and isinstance(cached_result, dict): + cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding") + # LOG SUCCESS cache_hit = True end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get("custom_llm_provider", None), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": True, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get("stream_response", {}), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - asyncio.create_task( - logging_obj.async_success_handler( - cached_result, start_time, end_time, cache_hit - ) - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(model=model, custom_llm_provider=kwargs.get('custom_llm_provider', None), api_base=kwargs.get('api_base', None), api_key=kwargs.get('api_key', None)) + print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") + logging_obj.update_environment_variables(model=model, user=kwargs.get('user', None), optional_params={}, litellm_params={"logger_fn": kwargs.get('logger_fn', None), "acompletion": True, "metadata": kwargs.get("metadata", {}), "model_info": kwargs.get("model_info", {}), "proxy_server_request": kwargs.get("proxy_server_request", None), "preset_cache_key": kwargs.get("preset_cache_key", None), "stream_response": kwargs.get("stream_response", {})}, input=kwargs.get('messages', ""), api_key=kwargs.get('api_key', 
None), original_response=str(cached_result), additional_args=None, stream=kwargs.get('stream', False)) + asyncio.create_task(logging_obj.async_success_handler(cached_result, start_time, end_time, cache_hit)) + threading.Thread(target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit)).start() return cached_result # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): + if "complete_response" in kwargs and kwargs["complete_response"] == True: chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: + return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) + else: return result - - ### POST-CALL RULES ### + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model) # [OPTIONAL] ADD TO CACHE - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - if isinstance(result, litellm.ModelResponse) or isinstance( - result, litellm.EmbeddingResponse - ): - asyncio.create_task( - litellm.cache._async_add_cache(result.json(), *args, **kwargs) - ) + if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse): + asyncio.create_task(litellm.cache._async_add_cache(result.json(), *args, **kwargs)) else: - asyncio.create_task( - litellm.cache._async_add_cache(result, *args, **kwargs) - ) + asyncio.create_task(litellm.cache._async_add_cache(result, *args, **kwargs)) # LOG SUCCESS - handle streaming success logging in the _next_ object - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - asyncio.create_task( - logging_obj.async_success_handler(result, start_time, end_time) - ) - threading.Thread( - target=logging_obj.success_handler, args=(result, start_time, end_time) - ).start() + print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") + asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time)) + threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() # RETURN RESULT if isinstance(result, ModelResponse): - result._response_ms = ( - end_time - start_time - ).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai return result - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() crash_reporting(*args, **kwargs, exception=traceback_exception) end_time = datetime.datetime.now() if logging_obj: try: - logging_obj.failure_handler( - e, traceback_exception, start_time, end_time - ) # DO NOT MAKE THREADED - router retry fallback relies on this! - except Exception as e: + logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! 
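# Runnable sketch of the "log success without blocking the response" pattern
# above: the async handler is scheduled with asyncio.create_task and the sync
# handler runs on a background thread, so the caller gets its result back
# immediately. Handler names are illustrative.
import asyncio
import threading

def sync_success_handler(result):
    print("sync log:", result["choices"][0]["message"]["content"])

async def async_success_handler(result):
    print("async log:", result["choices"][0]["message"]["content"])

async def wrapped_call():
    result = {"choices": [{"message": {"content": "hi"}}]}   # stand-in model call
    asyncio.create_task(async_success_handler(result))       # fire-and-forget
    threading.Thread(target=sync_success_handler, args=(result,)).start()
    return result                                            # returned right away

async def main():
    result = await wrapped_call()
    await asyncio.sleep(0.05)   # demo only: let the background task finish printing
    return result

asyncio.run(main())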
+ except Exception as e: raise e try: - await logging_obj.async_failure_handler( - e, traceback_exception, start_time, end_time - ) + await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time) except Exception as e: raise e - + call_type = original_function.__name__ if call_type == CallTypes.acompletion.value: num_retries = ( - kwargs.get("num_retries", None) or litellm.num_retries or None + kwargs.get("num_retries", None) + or litellm.num_retries + or None ) - litellm.num_retries = ( - None # set retries to None to prevent infinite loops - ) - context_window_fallback_dict = kwargs.get( - "context_window_fallback_dict", {} - ) - - if num_retries: - try: + litellm.num_retries = None # set retries to None to prevent infinite loops + context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) + + if num_retries: + try: kwargs["num_retries"] = num_retries kwargs["original_function"] = original_function - if isinstance( - e, openai.RateLimitError - ): # rate limiting specific error + if (isinstance(e, openai.RateLimitError)): # rate limiting specific error kwargs["retry_strategy"] = "exponential_backoff_retry" - elif isinstance(e, openai.APIError): # generic api error + elif (isinstance(e, openai.APIError)): # generic api error kwargs["retry_strategy"] = "constant_retry" return await litellm.acompletion_with_retries(*args, **kwargs) except: pass - elif ( - isinstance(e, litellm.exceptions.ContextWindowExceededError) - and context_window_fallback_dict - and model in context_window_fallback_dict - ): + elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return await original_function(*args, **kwargs) @@ -2276,7 +1861,6 @@ def client(original_function): else: return wrapper - ####### USAGE CALCULATOR ################ @@ -2284,10 +1868,7 @@ def client(original_function): # only used for together_computer LLMs def get_model_params_and_category(model_name): import re - - params_match = re.search( - r"(\d+b)", model_name - ) # catch all decimals like 3b, 70b, etc + params_match = re.search(r'(\d+b)', model_name) # catch all decimals like 3b, 70b, etc category = None if params_match != None: params_match = params_match.group(1) @@ -2308,36 +1889,30 @@ def get_model_params_and_category(model_name): return None - def get_replicate_completion_pricing(completion_response=None, total_time=0.0): # see https://replicate.com/pricing a100_40gb_price_per_second_public = 0.001150 # for all litellm currently supported LLMs, almost all requests go to a100_80gb - a100_80gb_price_per_second_public = ( - 0.001400 # assume all calls sent to A100 80GB for now - ) + a100_80gb_price_per_second_public = 0.001400 # assume all calls sent to A100 80GB for now if total_time == 0.0: - start_time = completion_response["created"] + start_time = completion_response['created'] end_time = completion_response["ended"] total_time = end_time - start_time - return a100_80gb_price_per_second_public * total_time + return a100_80gb_price_per_second_public*total_time -def _select_tokenizer(model: str): - # cohere +def _select_tokenizer(model: str): + # cohere import pkg_resources - if model in litellm.cohere_models: tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") return {"type": "huggingface_tokenizer", "tokenizer": 
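# Sketch of the retry-strategy selection above: rate-limit errors back off
# exponentially, generic API errors retry at a constant interval, everything
# else is not retried. The exception classes and delays are illustrative, not
# the openai SDK types.
class RateLimitError(Exception):
    pass

class APIError(Exception):
    pass

def pick_retry_delays(exc, num_retries=3):
    if isinstance(exc, RateLimitError):
        return [2 ** i for i in range(num_retries)]   # exponential backoff: 1s, 2s, 4s
    if isinstance(exc, APIError):
        return [1.0] * num_retries                    # constant retry: 1s per attempt
    return []                                         # other errors: don't retry

print(pick_retry_delays(RateLimitError()))   # [1, 2, 4]
print(pick_retry_delays(APIError()))         # [1.0, 1.0, 1.0]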
tokenizer} - # anthropic + # anthropic elif model in litellm.anthropic_models: # Read the JSON file - filename = pkg_resources.resource_filename( - __name__, "llms/tokenizers/anthropic_tokenizer.json" - ) - with open(filename, "r") as f: + filename = pkg_resources.resource_filename(__name__, 'llms/tokenizers/anthropic_tokenizer.json') + with open(filename, 'r') as f: json_data = json.load(f) # Decode the JSON data from utf-8 json_data_decoded = json.dumps(json_data, ensure_ascii=False) @@ -2346,16 +1921,15 @@ def _select_tokenizer(model: str): # load tokenizer tokenizer = Tokenizer.from_str(json_str) return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # llama2 - elif "llama-2" in model.lower(): + # llama2 + elif "llama-2" in model.lower(): tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} # default - tiktoken - else: + else: return {"type": "openai_tokenizer", "tokenizer": encoding} - -def encode(model: str, text: str): +def encode(model: str, text: str): """ Encodes the given text using the specified model. @@ -2370,18 +1944,12 @@ def encode(model: str, text: str): enc = tokenizer_json["tokenizer"].encode(text) return enc - -def decode(model: str, tokens: List[int]): +def decode(model: str, tokens: List[int]): tokenizer_json = _select_tokenizer(model=model) dec = tokenizer_json["tokenizer"].decode(tokens) return dec - -def openai_token_counter( - messages: Optional[list] = None, - model="gpt-3.5-turbo-0613", - text: Optional[str] = None, -): +def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-0613", text: Optional[str]= None): """ Return the number of tokens used by a list of messages. @@ -2393,9 +1961,7 @@ def openai_token_counter( print_verbose("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-3.5-turbo-0301": - tokens_per_message = ( - 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - ) + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted elif model in litellm.open_ai_chat_completion_models: tokens_per_message = 3 @@ -2406,9 +1972,9 @@ def openai_token_counter( ) num_tokens = 0 - if text: + if text: num_tokens = len(encoding.encode(text, disallowed_special=())) - elif messages: + elif messages: for message in messages: num_tokens += tokens_per_message for key, value in message.items(): @@ -2418,8 +1984,7 @@ def openai_token_counter( num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens - -def token_counter(model="", text=None, messages: Optional[List] = None): +def token_counter(model="", text=None, messages: Optional[List] = None): """ Count the number of tokens in a given text using a specified model. 
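# A compact, runnable version of the chat token-count formula used in
# openai_token_counter above (per-message overhead plus 3 tokens priming the
# assistant reply). Requires the tiktoken package; the constants match the
# gpt-3.5/gpt-4 chat format handled above.
import tiktoken

def count_message_tokens(messages, tokens_per_message=3, tokens_per_name=1):
    enc = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(enc.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens + 3   # every reply is primed with <|start|>assistant<|message|>

print(count_message_tokens([{"role": "user", "content": "Hello there"}]))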
@@ -2435,24 +2000,24 @@ def token_counter(model="", text=None, messages: Optional[List] = None): if text == None: if messages is not None: print_verbose(f"token_counter messages received: {messages}") - text = "" - for message in messages: + text = "" + for message in messages: if message.get("content", None): text += message["content"] - if "tool_calls" in message: - for tool_call in message["tool_calls"]: - if "function" in tool_call: - function_arguments = tool_call["function"]["arguments"] + if 'tool_calls' in message: + for tool_call in message['tool_calls']: + if 'function' in tool_call: + function_arguments = tool_call['function']['arguments'] text += function_arguments else: raise ValueError("text and messages cannot both be None") num_tokens = 0 if model is not None: tokenizer_json = _select_tokenizer(model=model) - if tokenizer_json["type"] == "huggingface_tokenizer": + if tokenizer_json["type"] == "huggingface_tokenizer": enc = tokenizer_json["tokenizer"].encode(text) num_tokens = len(enc.ids) - elif tokenizer_json["type"] == "openai_tokenizer": + elif tokenizer_json["type"] == "openai_tokenizer": if model in litellm.open_ai_chat_completion_models: num_tokens = openai_token_counter(text=text, model=model) else: @@ -2471,7 +2036,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): model (str): The name of the model to use. Default is "" prompt_tokens (int): The number of tokens in the prompt. completion_tokens (int): The number of tokens in the completion. - + Returns: tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively. """ @@ -2483,7 +2048,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): azure_llms = { "gpt-35-turbo": "azure/gpt-3.5-turbo", "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct", + "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct" } if model in model_cost_ref: prompt_tokens_cost_usd_dollar = ( @@ -2499,8 +2064,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens ) completion_tokens_cost_usd_dollar = ( - model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] - * completion_tokens + model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in azure_llms: @@ -2528,22 +2092,22 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): def completion_cost( - completion_response=None, - model=None, - prompt="", - messages: List = [], - completion="", - total_time=0.0, # used for replicate -): + completion_response=None, + model=None, + prompt="", + messages: List = [], + completion="", + total_time=0.0, # used for replicate + ): """ Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm. Parameters: completion_response (litellm.ModelResponses): [Required] The response received from a LiteLLM completion request. - + [OPTIONAL PARAMS] model (str): Optional. The name of the language model used in the completion calls - prompt (str): Optional. The input prompt passed to the llm + prompt (str): Optional. The input prompt passed to the llm completion (str): Optional. The output completion text from the llm total_time (float): Optional. 
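# Hedged sketch of the per-token cost arithmetic in cost_per_token above: look
# the model up in a price map and multiply the prompt/completion token counts by
# the per-token rates. The prices below are made-up placeholders, not real
# litellm pricing data.
MODEL_COST = {
    "example-model": {"input_cost_per_token": 1.5e-06, "output_cost_per_token": 2e-06},
}

def cost_per_token_sketch(model, prompt_tokens, completion_tokens):
    rates = MODEL_COST[model]
    prompt_cost = rates["input_cost_per_token"] * prompt_tokens
    completion_cost = rates["output_cost_per_token"] * completion_tokens
    return prompt_cost, completion_cost

p, c = cost_per_token_sketch("example-model", prompt_tokens=18, completion_tokens=23)
print(f"total = ${p + c:.8f}")   # -> total = $0.00007300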
(Only used for Replicate LLMs) The total time used for the request in seconds @@ -2566,46 +2130,41 @@ def completion_cost( completion_tokens = 0 if completion_response is not None: # get input/output tokens from completion_response - prompt_tokens = completion_response["usage"]["prompt_tokens"] - completion_tokens = completion_response["usage"]["completion_tokens"] - model = ( - model or completion_response["model"] - ) # check if user passed an override for model, if it's none check completion_response['model'] + prompt_tokens = completion_response['usage']['prompt_tokens'] + completion_tokens = completion_response['usage']['completion_tokens'] + model = model or completion_response['model'] # check if user passed an override for model, if it's none check completion_response['model'] else: if len(messages) > 0: prompt_tokens = token_counter(model=model, messages=messages) - elif len(prompt) > 0: + elif len(prompt) > 0: prompt_tokens = token_counter(model=model, text=prompt) completion_tokens = token_counter(model=model, text=completion) - + # Calculate cost based on prompt_tokens, completion_tokens if "togethercomputer" in model: # together ai prices based on size of llm - # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json + # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) # replicate llms are calculate based on time for request running # see https://replicate.com/pricing - elif model in litellm.replicate_models or "replicate" in model: + elif ( + model in litellm.replicate_models or + "replicate" in model + ): return get_replicate_completion_pricing(completion_response, total_time) - ( - prompt_tokens_cost_usd_dollar, - completion_tokens_cost_usd_dollar, - ) = cost_per_token( - model=model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, + prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token( + model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar except: - return 0.0 # this should not block a users execution path - + return 0.0 # this should not block a users execution path ####### HELPER FUNCTIONS ################ -def register_model(model_cost: Union[str, dict]): +def register_model(model_cost: Union[str, dict]): """ - Register new / Override existing models (and their pricing) to specific providers. + Register new / Override existing models (and their pricing) to specific providers. 
Provide EITHER a model cost dictionary or a url to a hosted json blob - Example usage: + Example usage: model_cost_dict = { "gpt-4": { "max_tokens": 8192, @@ -2617,60 +2176,59 @@ def register_model(model_cost: Union[str, dict]): } """ loaded_model_cost = {} - if isinstance(model_cost, dict): + if isinstance(model_cost, dict): loaded_model_cost = model_cost - elif isinstance(model_cost, str): + elif isinstance(model_cost, str): loaded_model_cost = litellm.get_model_cost_map(url=model_cost) for key, value in loaded_model_cost.items(): ## override / add new keys to the existing model cost dictionary if key in litellm.model_cost: - for k, v in loaded_model_cost[key].items(): + for k,v in loaded_model_cost[key].items(): litellm.model_cost[key][k] = v # add new model names to provider lists - if value.get("litellm_provider") == "openai": + if value.get('litellm_provider') == 'openai': if key not in litellm.open_ai_chat_completion_models: litellm.open_ai_chat_completion_models.append(key) - elif value.get("litellm_provider") == "text-completion-openai": + elif value.get('litellm_provider') == 'text-completion-openai': if key not in litellm.open_ai_text_completion_models: litellm.open_ai_text_completion_models.append(key) - elif value.get("litellm_provider") == "cohere": + elif value.get('litellm_provider') == 'cohere': if key not in litellm.cohere_models: litellm.cohere_models.append(key) - elif value.get("litellm_provider") == "anthropic": + elif value.get('litellm_provider') == 'anthropic': if key not in litellm.anthropic_models: litellm.anthropic_models.append(key) - elif value.get("litellm_provider") == "openrouter": - split_string = key.split("/", 1) + elif value.get('litellm_provider') == 'openrouter': + split_string = key.split('/', 1) if key not in litellm.openrouter_models: litellm.openrouter_models.append(split_string[1]) - elif value.get("litellm_provider") == "vertex_ai-text-models": + elif value.get('litellm_provider') == 'vertex_ai-text-models': if key not in litellm.vertex_text_models: litellm.vertex_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-text-models": + elif value.get('litellm_provider') == 'vertex_ai-code-text-models': if key not in litellm.vertex_code_text_models: litellm.vertex_code_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-chat-models": + elif value.get('litellm_provider') == 'vertex_ai-chat-models': if key not in litellm.vertex_chat_models: litellm.vertex_chat_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-chat-models": + elif value.get('litellm_provider') == 'vertex_ai-code-chat-models': if key not in litellm.vertex_code_chat_models: litellm.vertex_code_chat_models.append(key) - elif value.get("litellm_provider") == "ai21": + elif value.get('litellm_provider') == 'ai21': if key not in litellm.ai21_models: litellm.ai21_models.append(key) - elif value.get("litellm_provider") == "nlp_cloud": + elif value.get('litellm_provider') == 'nlp_cloud': if key not in litellm.nlp_cloud_models: litellm.nlp_cloud_models.append(key) - elif value.get("litellm_provider") == "aleph_alpha": + elif value.get('litellm_provider') == 'aleph_alpha': if key not in litellm.aleph_alpha_models: litellm.aleph_alpha_models.append(key) - elif value.get("litellm_provider") == "bedrock": + elif value.get('litellm_provider') == 'bedrock': if key not in litellm.bedrock_models: litellm.bedrock_models.append(key) return model_cost - def get_litellm_params( api_key=None, force_timeout=600, @@ -2689,7 +2247,7 @@ 
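# Hedged usage sketch for register_model as documented above: pass either a dict
# of model -> pricing/metadata entries or a URL to a hosted JSON blob. Assumes
# register_model is exposed at the litellm package level (as in the LiteLLM
# docs); the model name and prices are placeholders.
import litellm

litellm.register_model({
    "my-org/custom-chat-model": {
        "max_tokens": 8192,
        "input_cost_per_token": 0.00002,
        "output_cost_per_token": 0.00006,
        "litellm_provider": "openai",   # also appends the key to the matching provider list
    },
})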
def get_litellm_params( model_info=None, proxy_server_request=None, acompletion=None, - preset_cache_key=None, + preset_cache_key = None ): litellm_params = { "acompletion": acompletion, @@ -2706,21 +2264,20 @@ def get_litellm_params( "model_info": model_info, "proxy_server_request": proxy_server_request, "preset_cache_key": preset_cache_key, - "stream_response": {}, # litellm_call_id: ModelResponse Dict + "stream_response": {} # litellm_call_id: ModelResponse Dict } return litellm_params - def get_optional_params_image_gen( - n: Optional[int] = None, - quality: Optional[str] = None, - response_format: Optional[str] = None, - size: Optional[str] = None, - style: Optional[str] = None, - user: Optional[str] = None, - custom_llm_provider: Optional[str] = None, - **kwargs, + n: Optional[int]=None, + quality: Optional[str]=None, + response_format: Optional[str]=None, + size: Optional[str]=None, + style: Optional[str]=None, + user: Optional[str]=None, + custom_llm_provider: Optional[str]=None, + **kwargs ): # retrieve all parameters passed to the function passed_params = locals() @@ -2728,44 +2285,38 @@ def get_optional_params_image_gen( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - + default_params = { - "n": None, - "quality": None, - "response_format": None, - "size": None, + "n": None, + "quality" : None, + "response_format" : None, + "size": None, "style": None, "user": None, } - non_default_params = { - k: v - for k, v in passed_params.items() - if (k in default_params and v != default_params[k]) - } + non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError( - status_code=500, - message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", - ) - + raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. 
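# Minimal sketch of the "collect passed params, diff against defaults" pattern
# used by the get_optional_params_* helpers above: only parameters the caller
# actually changed survive the filter. Parameter names are illustrative.
def collect_non_defaults(**kwargs):
    defaults = {"n": None, "size": None, "user": None}
    passed = {**defaults, **kwargs}
    return {k: v for k, v in passed.items() if k in defaults and v != defaults[k]}

print(collect_non_defaults(n=2, user="abc"))   # {'n': 2, 'user': 'abc'}
print(collect_non_defaults(size=None))         # {} -> still the default, so dropped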
To drop it from the call, set `litellm.drop_params = True`.") + final_params = {**non_default_params, **kwargs} return final_params - + + def get_optional_params_embeddings( # 2 optional params - user=None, + user=None, encoding_format=None, custom_llm_provider="", - **kwargs, + **kwargs ): # retrieve all parameters passed to the function passed_params = locals() @@ -2773,31 +2324,26 @@ def get_optional_params_embeddings( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - - default_params = {"user": None, "encoding_format": None} - - non_default_params = { - k: v - for k, v in passed_params.items() - if (k in default_params and v != default_params[k]) + + default_params = { + "user": None, + "encoding_format": None } + + non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError( - status_code=500, - message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", - ) - + raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.") + final_params = {**non_default_params, **kwargs} return final_params - def get_optional_params( # use the openai defaults # 12 optional params functions=None, @@ -2819,7 +2365,7 @@ def get_optional_params( # use the openai defaults tools=None, tool_choice=None, max_retries=None, - **kwargs, + **kwargs ): # retrieve all parameters passed to the function passed_params = locals() @@ -2829,18 +2375,18 @@ def get_optional_params( # use the openai defaults default_params = { "functions": None, "function_call": None, - "temperature": None, - "top_p": None, - "n": None, - "stream": None, - "stop": None, - "max_tokens": None, - "presence_penalty": None, - "frequency_penalty": None, + "temperature":None, + "top_p":None, + "n":None, + "stream":None, + "stop":None, + "max_tokens":None, + "presence_penalty":None, + "frequency_penalty":None, "logit_bias": None, - "user": None, - "model": None, - "custom_llm_provider": "", + "user":None, + "model":None, + "custom_llm_provider":"", "response_format": None, "seed": None, "tools": None, @@ -2848,82 +2394,64 @@ def get_optional_params( # use the openai defaults "max_retries": None, } # filter out those parameters that were passed with non-default values - non_default_params = { - k: v - for k, v in passed_params.items() - if ( - k != "model" - and k != "custom_llm_provider" - and k in default_params - and v != default_params[k] - ) - } + non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])} optional_params = {} ## raise exception if function calling passed in for a provider that doesn't support it - if "functions" in non_default_params or "function_call" in 
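# Sketch of the drop_params behaviour above: for providers that don't accept the
# extra embedding parameters, either drop them silently (drop_params=True) or
# raise. The exception name is an illustrative stand-in.
class UnsupportedParams(Exception):
    pass

def handle_unsupported(non_default_params, provider, drop_params):
    if provider in ("openai", "azure") or not non_default_params:
        return non_default_params
    if drop_params:
        return {}                  # drop everything the provider can't take
    raise UnsupportedParams(f"{provider} does not support {list(non_default_params)}")

print(handle_unsupported({"encoding_format": "float"}, "cohere", drop_params=True))   # {}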
non_default_params: - if ( - custom_llm_provider != "openai" - and custom_llm_provider != "text-completion-openai" - and custom_llm_provider != "azure" - ): - if ( - litellm.add_function_to_prompt - ): # if user opts to add it to prompt instead - optional_params["functions_unsupported_model"] = non_default_params.pop( - "functions" - ) - else: - raise UnsupportedParamsError( - status_code=500, - message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.", - ) + if "functions" in non_default_params or "function_call" in non_default_params or "tools" in non_default_params: + if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure": + if custom_llm_provider == "ollama": + # ollama actually supports json output + optional_params["format"] = "json" + litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt + if "tools" in non_default_params: + optional_params["functions_unsupported_model"] = non_default_params.pop("tools") + non_default_params.pop("tool_choice", None) # causes ollama requests to hang + elif "functions" in non_default_params: + optional_params["functions_unsupported_model"] = non_default_params.pop("functions") + elif custom_llm_provider == "anyscale" and model == "mistralai/Mistral-7B-Instruct-v0.1": # anyscale just supports function calling with mistral + pass + elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead + optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions")) + else: + raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.") - def _check_valid_arg(supported_params): - print_verbose( - f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}" - ) + def _check_valid_arg(supported_params): + print_verbose(f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}") print_verbose(f"\nLiteLLM: Params passed to completion() {passed_params}") - print_verbose( - f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}" - ) + print_verbose(f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}") unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: - if k == "n" and n == 1: # langchain sends n=1 as a default value - continue # skip this param - if ( - k == "max_retries" - ): # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries - continue # skip this param + if k == "n" and n == 1: # langchain sends n=1 as a default value + continue # skip this param + if k == "max_retries": # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries + continue # skip this param # Always keeps this in elif code blocks - else: + else: unsupported_params[k] = non_default_params[k] if unsupported_params and not litellm.drop_params: - raise UnsupportedParamsError( - status_code=500, - message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.", - ) - + raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. 
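# Runnable sketch of the _map_and_modify_arg stop filter above: for amazon
# models on bedrock, only "|" runs or "User:" are kept as stop strings when
# litellm.drop_params is enabled; everything else is dropped.
import re

def filter_amazon_stop(stop):
    stop = [stop] if isinstance(stop, str) else stop
    return [s for s in stop if re.match(r"^(\|+|User:)$", s)]

print(filter_amazon_stop(["|||", "User:", "###", "\n\n"]))   # ['|||', 'User:']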
To drop these, set `litellm.drop_params=True`.") + def _map_and_modify_arg(supported_params: dict, provider: str, model: str): """ filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`. """ filtered_stop = None - if "stop" in supported_params and litellm.drop_params: - if provider == "bedrock" and "amazon" in model: + if "stop" in supported_params and litellm.drop_params: + if provider == "bedrock" and "amazon" in model: filtered_stop = [] - if isinstance(stop, list): - for s in stop: - if re.match(r"^(\|+|User:)$", s): - filtered_stop.append(s) - if filtered_stop is not None: + if isinstance(stop, list): + for s in stop: + if re.match(r'^(\|+|User:)$', s): + filtered_stop.append(s) + if filtered_stop is not None: supported_params["stop"] = filtered_stop return supported_params - ## raise exception if provider doesn't support passed in param + ## raise exception if provider doesn't support passed in param if custom_llm_provider == "anthropic": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle anthropic params @@ -2931,7 +2459,7 @@ def get_optional_params( # use the openai defaults optional_params["stream"] = stream if stop is not None: if type(stop) == str: - stop = [stop] # openai can accept str/list for stop + stop = [stop] # openai can accept str/list for stop optional_params["stop_sequences"] = stop if temperature is not None: optional_params["temperature"] = temperature @@ -2940,18 +2468,8 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens_to_sample"] = max_tokens elif custom_llm_provider == "cohere": - ## check if unsupported param passed in - supported_params = [ - "stream", - "temperature", - "max_tokens", - "logit_bias", - "top_p", - "frequency_penalty", - "presence_penalty", - "stop", - "n", - ] + ## check if unsupported param passed in + supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop", "n"] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2973,15 +2491,8 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "maritalk": - ## check if unsupported param passed in - supported_params = [ - "stream", - "temperature", - "max_tokens", - "top_p", - "presence_penalty", - "stop", - ] + ## check if unsupported param passed in + supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2999,24 +2510,17 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stopping_tokens"] = stop elif custom_llm_provider == "replicate": - ## check if unsupported param passed in - supported_params = [ - "stream", - "temperature", - "max_tokens", - "top_p", - "stop", - "seed", - ] + ## check if unsupported param passed in + supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream"] = stream return optional_params if max_tokens is not None: if "vicuna" in model or "flan" in model: optional_params["max_length"] = max_tokens - elif 
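# Sketch of the OpenAI -> Anthropic parameter translation above: a string stop
# becomes a one-element stop_sequences list and max_tokens is renamed to
# max_tokens_to_sample (the legacy Anthropic completion parameter).
def map_anthropic_params(stop=None, max_tokens=None, temperature=None, top_p=None, stream=None):
    out = {}
    if stream:
        out["stream"] = stream
    if stop is not None:
        out["stop_sequences"] = [stop] if isinstance(stop, str) else stop
    if temperature is not None:
        out["temperature"] = temperature
    if top_p is not None:
        out["top_p"] = top_p
    if max_tokens is not None:
        out["max_tokens_to_sample"] = max_tokens
    return out

print(map_anthropic_params(stop="Human:", max_tokens=256))
# {'stop_sequences': ['Human:'], 'max_tokens_to_sample': 256}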
"meta/codellama-13b" in model: + elif "meta/codellama-13b" in model: optional_params["max_tokens"] = max_tokens else: optional_params["max_new_tokens"] = max_tokens @@ -3027,7 +2531,7 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "huggingface": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None @@ -3041,9 +2545,7 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params[ - "do_sample" - ] = True # Need to sample if you want best of for hf inference endpoints + optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -3054,7 +2556,7 @@ def get_optional_params( # use the openai defaults if max_tokens == 0: max_tokens = 1 optional_params["max_new_tokens"] = max_tokens - if n is not None: + if n is not None: optional_params["best_of"] = n if presence_penalty is not None: optional_params["repetition_penalty"] = presence_penalty @@ -3062,21 +2564,12 @@ def get_optional_params( # use the openai defaults # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False optional_params["decoder_input_details"] = special_params["echo"] - passed_params.pop( - "echo", None - ) # since we handle translating echo, we should not send it to TGI request + passed_params.pop("echo", None) # since we handle translating echo, we should not send it to TGI request elif custom_llm_provider == "together_ai": - ## check if unsupported param passed in - supported_params = [ - "stream", - "temperature", - "max_tokens", - "top_p", - "stop", - "frequency_penalty", - ] + ## check if unsupported param passed in + supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty"] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream_tokens"] = stream if temperature is not None: @@ -3086,23 +2579,12 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens if frequency_penalty is not None: - optional_params[ - "repetition_penalty" - ] = frequency_penalty # https://docs.together.ai/reference/inference + optional_params["repetition_penalty"] = frequency_penalty # https://docs.together.ai/reference/inference if stop is not None: - optional_params["stop"] = stop + optional_params["stop"] = stop elif custom_llm_provider == "ai21": - ## check if unsupported param passed in - supported_params = [ - "stream", - "n", - "temperature", - "max_tokens", - "top_p", - "stop", - "frequency_penalty", - "presence_penalty", - ] + ## check if unsupported param passed in + supported_params = ["stream", "n", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty", "presence_penalty"] _check_valid_arg(supported_params=supported_params) if stream: @@ -3121,13 +2603,11 @@ def get_optional_params( # use the openai defaults 
optional_params["frequencyPenalty"] = {"scale": frequency_penalty} if presence_penalty is not None: optional_params["presencePenalty"] = {"scale": presence_penalty} - elif ( - custom_llm_provider == "palm" - ): # https://developers.generativeai.google/tutorials/curl_quickstart - ## check if unsupported param passed in + elif custom_llm_provider == "palm": # https://developers.generativeai.google/tutorials/curl_quickstart + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -3140,11 +2620,13 @@ def get_optional_params( # use the openai defaults optional_params["stop_sequences"] = stop if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens - elif custom_llm_provider == "vertex_ai": - ## check if unsupported param passed in + elif ( + custom_llm_provider == "vertex_ai" + ): + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "max_tokens", "stream"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -3154,7 +2636,7 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "sagemaker": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None @@ -3168,9 +2650,7 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params[ - "do_sample" - ] = True # Need to sample if you want best of for hf inference endpoints + optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -3193,7 +2673,7 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream elif "anthropic" in model: supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] @@ -3208,9 +2688,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if stop is not None: optional_params["stop_sequences"] = stop - if stream: + if stream: optional_params["stream"] = stream - elif "amazon" in model: # amazon titan llms + elif "amazon" in model: # amazon titan llms supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -3219,15 +2699,13 @@ def get_optional_params( # use the openai defaults if temperature is not None: optional_params["temperature"] = temperature if stop is not None: - filtered_stop = _map_and_modify_arg( - {"stop": stop}, provider="bedrock", model=model - ) + filtered_stop = _map_and_modify_arg({"stop": stop}, provider="bedrock", model=model) optional_params["stopSequences"] = filtered_stop["stop"] if top_p is not None: 
optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "meta" in model: # amazon / meta llms + elif "meta" in model: # amazon / meta llms supported_params = ["max_tokens", "temperature", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -3237,9 +2715,9 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "cohere" in model: # cohere models on bedrock + elif "cohere" in model: # cohere models on bedrock supported_params = ["stream", "temperature", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle cohere params @@ -3250,16 +2728,7 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "aleph_alpha": - supported_params = [ - "max_tokens", - "stream", - "top_p", - "temperature", - "presence_penalty", - "frequency_penalty", - "n", - "stop", - ] + supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: optional_params["maximum_tokens"] = max_tokens @@ -3278,16 +2747,9 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "ollama": - supported_params = [ - "max_tokens", - "stream", - "top_p", - "temperature", - "frequency_penalty", - "stop", - ] + supported_params = ["max_tokens", "stream", "top_p", "temperature", "frequency_penalty", "stop"] _check_valid_arg(supported_params=supported_params) - + if max_tokens is not None: optional_params["num_predict"] = max_tokens if stream: @@ -3301,16 +2763,7 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "nlp_cloud": - supported_params = [ - "max_tokens", - "stream", - "temperature", - "top_p", - "presence_penalty", - "frequency_penalty", - "n", - "stop", - ] + supported_params = ["max_tokens", "stream", "temperature", "top_p", "presence_penalty", "frequency_penalty", "n", "stop"] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: @@ -3342,84 +2795,62 @@ def get_optional_params( # use the openai defaults if stream: optional_params["stream"] = stream elif custom_llm_provider == "deepinfra": - supported_params = [ - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - ] + supported_params = ["temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if ( - temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" - ): # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature if top_p: optional_params["top_p"] = top_p - if n: + if n: optional_params["n"] = n - if stream: + if stream: 
optional_params["stream"] = stream - if stop: + if stop: optional_params["stop"] = stop - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty - if logit_bias: + if logit_bias: optional_params["logit_bias"] = logit_bias - if user: + if user: optional_params["user"] = user elif custom_llm_provider == "perplexity": - supported_params = [ - "temperature", - "top_p", - "stream", - "max_tokens", - "presence_penalty", - "frequency_penalty", - ] + supported_params = ["temperature", "top_p", "stream", "max_tokens", "presence_penalty", "frequency_penalty"] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if ( - temperature == 0 and model == "mistral-7b-instruct" - ): # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if temperature == 0 and model == "mistral-7b-instruct": # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty elif custom_llm_provider == "anyscale": - supported_params = ["temperature", "top_p", "stream", "max_tokens"] + supported_params = ["temperature", "top_p", "stream", "max_tokens", "stop", "frequency_penalty", "presence_penalty"] + if model == "mistralai/Mistral-7B-Instruct-v0.1": + supported_params += ["functions", "function_call", "tools", "tool_choice"] _check_valid_arg(supported_params=supported_params) optional_params = non_default_params if temperature is not None: - if ( - temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" - ): # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "mistral": supported_params = ["temperature", "top_p", "stream", "max_tokens"] @@ -3427,13 +2858,13 @@ def get_optional_params( # use the openai defaults optional_params = non_default_params if temperature is not None: optional_params["temperature"] = temperature - if top_p is not None: + if top_p is not None: optional_params["top_p"] = top_p - if stream is not None: + if stream is not None: optional_params["stream"] = stream - if max_tokens is not None: + if max_tokens is not None: optional_params["max_tokens"] = max_tokens - + # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion safe_mode = passed_params.pop("safe_mode", None) random_seed = passed_params.pop("random_seed", None) @@ -3442,29 +2873,9 @@ def get_optional_params( # use the openai defaults extra_body["safe_mode"] = safe_mode if random_seed is not None: extra_body["random_seed"] = 
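# Sketch of the temperature clamp above, generalized for illustration: the
# hosted Mistral-7B-Instruct endpoints handled here reject temperature == 0, so
# it is nudged to a small positive value instead.
def clamp_temperature(temperature, model):
    if temperature == 0 and "mistral" in model.lower():
        return 0.0001   # close to 0, but accepted by the endpoint
    return temperature

print(clamp_temperature(0, "mistralai/Mistral-7B-Instruct-v0.1"))    # 0.0001
print(clamp_temperature(0.7, "mistralai/Mistral-7B-Instruct-v0.1"))  # 0.7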
random_seed - optional_params[ - "extra_body" - ] = extra_body # openai client supports `extra_body` param + optional_params["extra_body"] = extra_body # openai client supports `extra_body` param elif custom_llm_provider == "openrouter": - supported_params = [ - "functions", - "function_call", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - ] + supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -3501,7 +2912,7 @@ def get_optional_params( # use the openai defaults optional_params["tool_choice"] = tool_choice if max_retries is not None: optional_params["max_retries"] = max_retries - + # OpenRouter-only parameters extra_body = {} transforms = passed_params.pop("transforms", None) @@ -3513,29 +2924,9 @@ def get_optional_params( # use the openai defaults extra_body["models"] = models if route is not None: extra_body["route"] = route - optional_params[ - "extra_body" - ] = extra_body # openai client supports `extra_body` param + optional_params["extra_body"] = extra_body # openai client supports `extra_body` param else: # assume passing in params for openai/azure openai - supported_params = [ - "functions", - "function_call", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - ] + supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"] _check_valid_arg(supported_params=supported_params) if functions is not None: optional_params["functions"] = functions @@ -3572,44 +2963,35 @@ def get_optional_params( # use the openai defaults if max_retries is not None: optional_params["max_retries"] = max_retries optional_params = non_default_params - # if user passed in non-default kwargs for specific providers/models, pass them along - for k in passed_params.keys(): - if k not in default_params.keys(): + # if user passed in non-default kwargs for specific providers/models, pass them along + for k in passed_params.keys(): + if k not in default_params.keys(): optional_params[k] = passed_params[k] return optional_params - -def get_llm_provider( - model: str, - custom_llm_provider: Optional[str] = None, - api_base: Optional[str] = None, - api_key: Optional[str] = None, -): +def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None): try: dynamic_api_key = None # check if llm provider provided - + if custom_llm_provider: return model, custom_llm_provider, dynamic_api_key, api_base - - if api_key and api_key.startswith("os.environ/"): + + if api_key and api_key.startswith("os.environ/"): dynamic_api_key = get_secret(api_key) # check if llm provider part of model name - if ( - model.split("/", 1)[0] in litellm.provider_list - and model.split("/", 1)[0] not in litellm.model_list - ): + if model.split("/",1)[0] in litellm.provider_list and 
model.split("/",1)[0] not in litellm.model_list: custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] if custom_llm_provider == "perplexity": # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai api_base = "https://api.perplexity.ai" dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") - elif custom_llm_provider == "anyscale": + elif custom_llm_provider == "anyscale": # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.endpoints.anyscale.com/v1" dynamic_api_key = get_secret("ANYSCALE_API_KEY") - elif custom_llm_provider == "deepinfra": + elif custom_llm_provider == "deepinfra": # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.deepinfra.com/v1/openai" dynamic_api_key = get_secret("DEEPINFRA_API_KEY") @@ -3620,7 +3002,7 @@ def get_llm_provider( return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint - if api_base: + if api_base: for endpoint in litellm.openai_compatible_endpoints: if endpoint in api_base: if endpoint == "api.perplexity.ai": @@ -3639,26 +3021,20 @@ def get_llm_provider( # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) ## openai - chatcompletion + text completion - if ( - model in litellm.open_ai_chat_completion_models - or "ft:gpt-3.5-turbo" in model - or model in litellm.openai_image_generation_models - ): + if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models: custom_llm_provider = "openai" elif model in litellm.open_ai_text_completion_models: custom_llm_provider = "text-completion-openai" - ## anthropic + ## anthropic elif model in litellm.anthropic_models: custom_llm_provider = "anthropic" ## cohere elif model in litellm.cohere_models or model in litellm.cohere_embedding_models: custom_llm_provider = "cohere" ## replicate - elif model in litellm.replicate_models or (":" in model and len(model) > 64): + elif model in litellm.replicate_models or (":" in model and len(model)>64): model_parts = model.split(":") - if ( - len(model_parts) > 1 and len(model_parts[1]) == 64 - ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + if len(model_parts) > 1 and len(model_parts[1])==64: ## checks if model name has a 64 digit code - e.g. 
"meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" custom_llm_provider = "replicate" elif model in litellm.replicate_models: custom_llm_provider = "replicate" @@ -3668,22 +3044,22 @@ def get_llm_provider( ## openrouter elif model in litellm.maritalk_models: custom_llm_provider = "maritalk" - ## vertex - text + chat + language (gemini) models - elif ( - model in litellm.vertex_chat_models - or model in litellm.vertex_code_chat_models - or model in litellm.vertex_text_models - or model in litellm.vertex_code_text_models - or model in litellm.vertex_language_models + ## vertex - text + chat + language (gemini) models + elif( + model in litellm.vertex_chat_models or + model in litellm.vertex_code_chat_models or + model in litellm.vertex_text_models or + model in litellm.vertex_code_text_models or + model in litellm.vertex_language_models ): custom_llm_provider = "vertex_ai" - ## ai21 + ## ai21 elif model in litellm.ai21_models: custom_llm_provider = "ai21" - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: custom_llm_provider = "aleph_alpha" - ## baseten + ## baseten elif model in litellm.baseten_models: custom_llm_provider = "baseten" ## nlp_cloud @@ -3693,80 +3069,107 @@ def get_llm_provider( elif model in litellm.petals_models: custom_llm_provider = "petals" ## bedrock - elif ( - model in litellm.bedrock_models or model in litellm.bedrock_embedding_models - ): + elif model in litellm.bedrock_models or model in litellm.bedrock_embedding_models: custom_llm_provider = "bedrock" # openai embeddings elif model in litellm.open_ai_embedding_models: custom_llm_provider = "openai" - if custom_llm_provider is None or custom_llm_provider == "": - print() # noqa - print( - "\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" - ) # noqa - print() # noqa + if custom_llm_provider is None or custom_llm_provider=="": + print() # noqa + print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m") # noqa + print() # noqa error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. 
For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers" # maps to openai.NotFoundError, this is raised when openai does not recognize the llm - raise litellm.exceptions.NotFoundError( # type: ignore + raise litellm.exceptions.NotFoundError( # type: ignore message=error_str, model=model, response=httpx.Response( status_code=404, - content=error_str, - request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore + content=error_str, + request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm") # type: ignore ), - llm_provider="", + llm_provider="" ) return model, custom_llm_provider, dynamic_api_key, api_base - except Exception as e: + except Exception as e: raise e def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): - api_key = dynamic_api_key or litellm.api_key - # openai + api_key = (dynamic_api_key or litellm.api_key) + # openai if llm_provider == "openai" or llm_provider == "text-completion-openai": - api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") - # anthropic + api_key = ( + api_key or + litellm.openai_key or + get_secret("OPENAI_API_KEY") + ) + # anthropic elif llm_provider == "anthropic": - api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY") - # ai21 + api_key = ( + api_key or + litellm.anthropic_key or + get_secret("ANTHROPIC_API_KEY") + ) + # ai21 elif llm_provider == "ai21": - api_key = api_key or litellm.ai21_key or get_secret("AI211_API_KEY") - # aleph_alpha + api_key = ( + api_key or + litellm.ai21_key or + get_secret("AI211_API_KEY") + ) + # aleph_alpha elif llm_provider == "aleph_alpha": api_key = ( - api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") + api_key or + litellm.aleph_alpha_key or + get_secret("ALEPH_ALPHA_API_KEY") ) - # baseten + # baseten elif llm_provider == "baseten": - api_key = api_key or litellm.baseten_key or get_secret("BASETEN_API_KEY") - # cohere + api_key = ( + api_key or + litellm.baseten_key or + get_secret("BASETEN_API_KEY") + ) + # cohere elif llm_provider == "cohere": - api_key = api_key or litellm.cohere_key or get_secret("COHERE_API_KEY") - # huggingface + api_key = ( + api_key or + litellm.cohere_key or + get_secret("COHERE_API_KEY") + ) + # huggingface elif llm_provider == "huggingface": api_key = ( - api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") + api_key or + litellm.huggingface_key or + get_secret("HUGGINGFACE_API_KEY") ) - # nlp_cloud + # nlp_cloud elif llm_provider == "nlp_cloud": - api_key = api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") - # replicate + api_key = ( + api_key or + litellm.nlp_cloud_key or + get_secret("NLP_CLOUD_API_KEY") + ) + # replicate elif llm_provider == "replicate": - api_key = api_key or litellm.replicate_key or get_secret("REPLICATE_API_KEY") - # together_ai + api_key = ( + api_key or + litellm.replicate_key or + get_secret("REPLICATE_API_KEY") + ) + # together_ai elif llm_provider == "together_ai": api_key = ( - api_key - or litellm.togetherai_api_key - or get_secret("TOGETHERAI_API_KEY") - or get_secret("TOGETHER_AI_TOKEN") + api_key or + litellm.togetherai_api_key or + get_secret("TOGETHERAI_API_KEY") or + get_secret("TOGETHER_AI_TOKEN") ) return api_key - def get_max_tokens(model: str): """ Get the maximum number of tokens allowed for a given model. 
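For orientation, the two helpers diffed above compose in a fixed precedence: get_llm_provider resolves the provider (and any `os.environ/` dynamic key plus api_base), then get_api_key falls back through the dynamic key, litellm.api_key, the provider-specific module attribute, and finally the provider's environment variable. A minimal usage sketch, assuming both helpers are imported from litellm.utils and that OPENAI_API_KEY is exported; the placeholder key value is illustrative only:

# Hypothetical usage sketch: resolve the provider for a model string, then its key.
import os
from litellm.utils import get_llm_provider, get_api_key

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # placeholder for the sketch

# "gpt-3.5-turbo" is in litellm.open_ai_chat_completion_models, so no "provider/" prefix is needed.
model, provider, dynamic_key, api_base = get_llm_provider(model="gpt-3.5-turbo")
assert provider == "openai" and dynamic_key is None and api_base is None

# Falls back through litellm.api_key -> litellm.openai_key -> the OPENAI_API_KEY env var.
api_key = get_api_key(llm_provider=provider, dynamic_api_key=dynamic_key)
print(provider, api_key is not None)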
@@ -3784,7 +3187,6 @@ def get_max_tokens(model: str): >>> get_max_tokens("gpt-4") 8192 """ - def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3810,21 +3212,19 @@ def get_max_tokens(model: str): try: if model in litellm.model_cost: return litellm.model_cost[model]["max_tokens"] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return max_tokens - else: + else: raise Exception() except: - raise Exception( - "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" - ) + raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") def get_model_info(model: str): """ - Get a dict for the maximum tokens (context window), + Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. Parameters: @@ -3851,7 +3251,6 @@ def get_model_info(model: str): "mode": "chat" } """ - def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3873,34 +3272,30 @@ def get_model_info(model: str): return None except requests.exceptions.RequestException as e: return None - try: azure_llms = { "gpt-35-turbo": "azure/gpt-3.5-turbo", "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct", + "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct" } - if model in azure_llms: + if model in azure_llms: model = azure_llms[model] if model in litellm.model_cost: return litellm.model_cost[model] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return { "max_tokens": max_tokens, "input_cost_per_token": 0, "output_cost_per_token": 0, "litellm_provider": "huggingface", - "mode": "chat", + "mode": "chat" } - else: + else: raise Exception() except: - raise Exception( - "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" - ) - + raise Exception("This model isn't mapped yet. 
Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") def json_schema_type(python_type_name: str): """Converts standard python types to json schema types @@ -3927,7 +3322,6 @@ def json_schema_type(python_type_name: str): return python_to_json_schema_types.get(python_type_name, "string") - def function_to_dict(input_function): # noqa: C901 """Using type hints and numpy-styled docstring, produce a dictionnary usable for OpenAI function calling @@ -4017,7 +3411,6 @@ def function_to_dict(input_function): # noqa: C901 return result - def load_test_model( model: str, custom_llm_provider: str = "", @@ -4060,14 +3453,13 @@ def load_test_model( "exception": e, } - -def validate_environment(model: Optional[str] = None) -> dict: +def validate_environment(model: Optional[str]=None) -> dict: """ Checks if the environment variables are valid for the given model. - + Args: model (Optional[str]): The name of the model. Defaults to None. - + Returns: dict: A dictionary containing the following keys: - keys_in_environment (bool): True if all the required keys are present in the environment, False otherwise. @@ -4077,10 +3469,7 @@ def validate_environment(model: Optional[str] = None) -> dict: missing_keys: List[str] = [] if model is None: - return { - "keys_in_environment": keys_in_environment, - "missing_keys": missing_keys, - } + return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} ## EXTRACT LLM PROVIDER - if model name provided try: custom_llm_provider = get_llm_provider(model=model) @@ -4091,7 +3480,7 @@ def validate_environment(model: Optional[str] = None) -> dict: # custom_llm_provider = model.split("/", 1)[0] # model = model.split("/", 1)[1] # custom_llm_provider_passed_in = True - + if custom_llm_provider: if custom_llm_provider == "openai": if "OPENAI_API_KEY" in os.environ: @@ -4099,16 +3488,12 @@ def validate_environment(model: Optional[str] = None) -> dict: else: missing_keys.append("OPENAI_API_KEY") elif custom_llm_provider == "azure": - if ( - "AZURE_API_BASE" in os.environ + if ("AZURE_API_BASE" in os.environ and "AZURE_API_VERSION" in os.environ - and "AZURE_API_KEY" in os.environ - ): + and "AZURE_API_KEY" in os.environ): keys_in_environment = True else: - missing_keys.extend( - ["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"] - ) + missing_keys.extend(["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"]) elif custom_llm_provider == "anthropic": if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -4130,7 +3515,8 @@ def validate_environment(model: Optional[str] = None) -> dict: else: missing_keys.append("OPENROUTER_API_KEY") elif custom_llm_provider == "vertex_ai": - if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: + if ("VERTEXAI_PROJECT" in os.environ + and "VERTEXAI_LOCATION" in os.environ): keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) @@ -4164,26 +3550,20 @@ def validate_environment(model: Optional[str] = None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - elif custom_llm_provider == "bedrock": - if ( - "AWS_ACCESS_KEY_ID" in os.environ - and "AWS_SECRET_ACCESS_KEY" in os.environ - ): + elif custom_llm_provider == "bedrock": + if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("AWS_ACCESS_KEY_ID") missing_keys.append("AWS_SECRET_ACCESS_KEY") else: ## openai - chatcompletion + text completion - 
if ( - model in litellm.open_ai_chat_completion_models - or litellm.open_ai_text_completion_models - ): + if model in litellm.open_ai_chat_completion_models or litellm.open_ai_text_completion_models: if "OPENAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("OPENAI_API_KEY") - ## anthropic + ## anthropic elif model in litellm.anthropic_models: if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -4209,35 +3589,36 @@ def validate_environment(model: Optional[str] = None) -> dict: missing_keys.append("OPENROUTER_API_KEY") ## vertex - text + chat models elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models: - if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: + if ("VERTEXAI_PROJECT" in os.environ + and "VERTEXAI_LOCATION" in os.environ): keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) - ## huggingface + ## huggingface elif model in litellm.huggingface_models: if "HUGGINGFACE_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("HUGGINGFACE_API_KEY") - ## ai21 + ## ai21 elif model in litellm.ai21_models: if "AI21_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("AI21_API_KEY") - ## together_ai + ## together_ai elif model in litellm.together_ai_models: if "TOGETHERAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("TOGETHERAI_API_KEY") - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: if "ALEPH_ALPHA_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("ALEPH_ALPHA_API_KEY") - ## baseten + ## baseten elif model in litellm.baseten_models: if "BASETEN_API_KEY" in os.environ: keys_in_environment = True @@ -4249,8 +3630,7 @@ def validate_environment(model: Optional[str] = None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} - + return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} def set_callbacks(callback_list, function_id=None): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger @@ -4344,7 +3724,6 @@ def set_callbacks(callback_list, function_id=None): except Exception as e: raise e - # NOTE: DEPRECATING this in favor of using failure_handler() in Logging: def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger @@ -4488,8 +3867,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k exception_logging(logger_fn=user_logger_fn, exception=e) pass - -async def convert_to_streaming_response_async(response_object: Optional[dict] = None): +async def convert_to_streaming_response_async(response_object: Optional[dict]=None): """ Asynchronously converts a response object to a streaming response. 
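As a usage note for validate_environment above: it only inspects os.environ and reports which provider keys are absent; it never contacts the provider. A small hedged sketch of a pre-flight check (the model name is illustrative, and the import path assumes the function stays in litellm.utils as shown):

# Hypothetical pre-flight check before issuing a completion call.
from litellm.utils import validate_environment

env_check = validate_environment(model="claude-2")
if not env_check["keys_in_environment"]:
    # For an Anthropic model this would typically report ["ANTHROPIC_API_KEY"].
    print("Missing environment variables:", env_check["missing_keys"])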
@@ -4520,7 +3898,7 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] = content=choice["message"].get("content", None), role=choice["message"]["role"], function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None), + tool_calls=choice["message"].get("tool_calls", None) ) finish_reason = choice.get("finish_reason", None) @@ -4536,9 +3914,10 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] = model_response_object.usage = Usage( completion_tokens=response_object["usage"].get("completion_tokens", 0), prompt_tokens=response_object["usage"].get("prompt_tokens", 0), - total_tokens=response_object["usage"].get("total_tokens", 0), + total_tokens=response_object["usage"].get("total_tokens", 0) ) + if "id" in response_object: model_response_object.id = response_object["id"] @@ -4551,20 +3930,19 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] = yield model_response_object await asyncio.sleep(0) - -def convert_to_streaming_response(response_object: Optional[dict] = None): +def convert_to_streaming_response(response_object: Optional[dict]=None): # used for yielding Cache hits when stream == True if response_object is None: raise Exception("Error in response object format") model_response_object = ModelResponse(stream=True) - choice_list = [] - for idx, choice in enumerate(response_object["choices"]): + choice_list=[] + for idx, choice in enumerate(response_object["choices"]): delta = Delta( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None), + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None) ) finish_reason = choice.get("finish_reason", None) if finish_reason == None: @@ -4575,118 +3953,100 @@ def convert_to_streaming_response(response_object: Optional[dict] = None): model_response_object.choices = choice_list if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: + if "id" in response_object: model_response_object.id = response_object["id"] if "system_fingerprint" in response_object: model_response_object.system_fingerprint = response_object["system_fingerprint"] - if "model" in response_object: + if "model" in response_object: model_response_object.model = response_object["model"] yield model_response_object -def convert_to_model_response_object( - response_object: Optional[dict] = None, - model_response_object: Optional[ - Union[ModelResponse, EmbeddingResponse, ImageResponse] - ] = None, - response_type: Literal[ - "completion", "embedding", 
"image_generation" - ] = "completion", - stream=False, -): - try: - if response_type == "completion" and ( - model_response_object is None - or isinstance(model_response_object, ModelResponse) - ): - if response_object is None or model_response_object is None: - raise Exception("Error in response object format") - if stream == True: - # for returning cached responses, we need to yield a generator - return convert_to_streaming_response(response_object=response_object) - choice_list = [] - for idx, choice in enumerate(response_object["choices"]): - message = Message( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None), - ) - finish_reason = choice.get("finish_reason", None) - if finish_reason == None: - # gpt-4 vision can return 'finish_reason' or 'finish_details' - finish_reason = choice.get("finish_details") - choice = Choices( - finish_reason=finish_reason, index=idx, message=message - ) - choice_list.append(choice) - model_response_object.choices = choice_list +def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False): + try: + if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)): + if response_object is None or model_response_object is None: + raise Exception("Error in response object format") + if stream == True: + # for returning cached responses, we need to yield a generator + return convert_to_streaming_response( + response_object=response_object + ) + choice_list=[] + for idx, choice in enumerate(response_object["choices"]): + message = Message( + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None) + ) + finish_reason = choice.get("finish_reason", None) + if finish_reason == None: + # gpt-4 vision can return 'finish_reason' or 'finish_details' + finish_reason = choice.get("finish_details") + choice = Choices(finish_reason=finish_reason, index=idx, message=message) + choice_list.append(choice) + model_response_object.choices = choice_list - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: - model_response_object.id = response_object["id"] + if "id" in response_object: + model_response_object.id = response_object["id"] + + if "system_fingerprint" in response_object: + model_response_object.system_fingerprint = 
response_object["system_fingerprint"] - if "system_fingerprint" in response_object: - model_response_object.system_fingerprint = response_object[ - "system_fingerprint" - ] + if "model" in response_object: + model_response_object.model = response_object["model"] + return model_response_object + elif response_type == "embedding" and (model_response_object is None or isinstance(model_response_object, EmbeddingResponse)): + if response_object is None: + raise Exception("Error in response object format") + + if model_response_object is None: + model_response_object = EmbeddingResponse() - if "model" in response_object: - model_response_object.model = response_object["model"] - return model_response_object - elif response_type == "embedding" and ( - model_response_object is None - or isinstance(model_response_object, EmbeddingResponse) - ): - if response_object is None: - raise Exception("Error in response object format") + if "model" in response_object: + model_response_object.model = response_object["model"] + + if "object" in response_object: + model_response_object.object = response_object["object"] - if model_response_object is None: - model_response_object = EmbeddingResponse() - - if "model" in response_object: - model_response_object.model = response_object["model"] - - if "object" in response_object: - model_response_object.object = response_object["object"] - - model_response_object.data = response_object["data"] - - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - - return model_response_object - elif response_type == "image_generation" and ( - model_response_object is None - or isinstance(model_response_object, ImageResponse) - ): - if response_object is None: - raise Exception("Error in response object format") - - if model_response_object is None: - model_response_object = ImageResponse() - - if "created" in response_object: - model_response_object.created = response_object["created"] - - if "data" in response_object: + model_response_object.data = response_object["data"] - return model_response_object - except Exception as e: - raise Exception(f"Invalid response object {e}") + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + + + return model_response_object + elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)): + if response_object is None: + raise Exception("Error in response object format") + + if model_response_object is None: + model_response_object = ImageResponse() + + if "created" in response_object: + model_response_object.created = response_object["created"] + + if "data" in response_object: + model_response_object.data = response_object["data"] + + return model_response_object + except Exception as e: + raise Exception(f"Invalid response object {e}") # NOTE: DEPRECATING this in favor of 
using success_handler() in Logging: @@ -4792,7 +4152,6 @@ def valid_model(model): except: raise BadRequestError(message="", model=model, llm_provider="") - def check_valid_key(model: str, api_key: str): """ Checks if a given API key is valid for a specific model by making a litellm.completion call with max_tokens=10 @@ -4806,19 +4165,16 @@ def check_valid_key(model: str, api_key: str): """ messages = [{"role": "user", "content": "Hey, how's it going?"}] try: - litellm.completion( - model=model, messages=messages, api_key=api_key, max_tokens=10 - ) + litellm.completion(model=model, messages=messages, api_key=api_key, max_tokens=10) return True except AuthenticationError as e: return False except Exception as e: return False - -def _should_retry(status_code: int): +def _should_retry(status_code: int): """ - Reimplementation of openai's should retry logic, since that one can't be imported. + Reimplementation of openai's should retry logic, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639 """ # If the server explicitly says whether or not to retry, obey. @@ -4840,20 +4196,13 @@ def _should_retry(status_code: int): return False - -def _calculate_retry_after( - remaining_retries: int, - max_retries: int, - response_headers: Optional[httpx.Headers] = None, - min_timeout: int = 0, -): +def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None, min_timeout: int = 0): """ Reimplementation of openai's calculate retry after, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631 """ try: - import email # openai import - + import email # openai import # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After # # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for @@ -4874,7 +4223,7 @@ def _calculate_retry_after( except Exception: retry_after = -1 - + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. if 0 < retry_after <= 60: return retry_after @@ -4891,7 +4240,6 @@ def _calculate_retry_after( timeout = sleep_seconds * jitter return timeout if timeout >= min_timeout else min_timeout - # integration helper function def modify_integration(integration_name, integration_params): global supabaseClient @@ -4901,12 +4249,7 @@ def modify_integration(integration_name, integration_params): # custom prompt helper function -def register_prompt_template( - model: str, - roles: dict, - initial_prompt_value: str = "", - final_prompt_value: str = "", -): +def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""): """ Register a prompt template to follow your custom format for a given model @@ -4920,19 +4263,19 @@ def register_prompt_template( dict: The updated custom prompt dictionary. 
Example usage: ``` - import litellm + import litellm litellm.register_prompt_template( - model="llama-2", + model="llama-2", initial_prompt_value="You are a good assistant" # [OPTIONAL] - roles={ + roles={ "system": { "pre_message": "[INST] <>\n", # [OPTIONAL] "post_message": "\n<>\n [/INST]\n" # [OPTIONAL] }, - "user": { + "user": { "pre_message": "[INST] ", # [OPTIONAL] "post_message": " [/INST]" # [OPTIONAL] - }, + }, "assistant": { "pre_message": "\n" # [OPTIONAL] "post_message": "\n" # [OPTIONAL] @@ -4946,12 +4289,11 @@ def register_prompt_template( litellm.custom_prompt_dict[model] = { "roles": roles, "initial_prompt_value": initial_prompt_value, - "final_prompt_value": final_prompt_value, + "final_prompt_value": final_prompt_value } return litellm.custom_prompt_dict - -####### DEPRECATED ################ +####### DEPRECATED ################ def get_all_keys(llm_provider=None): @@ -5040,25 +4382,20 @@ def get_model_list(): f"[Non-Blocking Error] get_model_list error - {traceback.format_exc()}" ) - ####### EXCEPTION MAPPING ################ def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, -): + model, + original_exception, + custom_llm_provider, + completion_kwargs={}, + ): global user_logger_fn, liteDebuggerClient exception_mapping_worked = False if litellm.suppress_debug_info is False: - print() # noqa - print( - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" - ) # noqa - print( - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." - ) # noqa - print() # noqa + print() # noqa + print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m") # noqa + print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa + print() # noqa try: if model: error_str = str(original_exception) @@ -5066,53 +4403,39 @@ def exception_type( exception_type = type(original_exception).__name__ else: exception_type = "" - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: + + if "Request Timeout Error" in error_str or "Request timed out" in error_str: exception_mapping_worked = True raise Timeout( message=f"APITimeoutError - Request timed out", model=model, - llm_provider=custom_llm_provider, + llm_provider=custom_llm_provider ) - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - if ( - "This model's maximum context length is" in error_str - or "Request too large" in error_str - ): + if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai" or custom_llm_provider in litellm.openai_compatible_providers: + if "This model's maximum context length is" in error_str or "Request too large" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): + elif "invalid_request_error" in error_str and "model_not_found" in error_str: exception_mapping_worked = True raise NotFoundError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - 
response=original_exception.response, + response=original_exception.response ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): + elif "invalid_request_error" in error_str and "Incorrect API key provided" not in error_str: exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response, + response=original_exception.response ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -5122,7 +4445,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 404: exception_mapping_worked = True @@ -5130,7 +4453,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5145,7 +4468,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5153,17 +4476,17 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response, + response=original_exception.response ) - elif original_exception.status_code == 503: + elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response, + response=original_exception.response ) - elif original_exception.status_code == 504: # gateway timeout error + elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( message=f"OpenAIException - {original_exception.message}", @@ -5173,11 +4496,11 @@ def exception_type( else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - request=original_exception.request, + request=original_exception.request ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -5185,28 +4508,25 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider=custom_llm_provider, model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "anthropic": # one of the anthropics if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): + if "prompt is too long" in original_exception.message or "prompt: length" in original_exception.message: exception_mapping_worked = True raise ContextWindowExceededError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response, + response=original_exception.response ) if 
"Invalid API Key" in original_exception.message: exception_mapping_worked = True raise AuthenticationError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): print_verbose(f"status_code: {original_exception.status_code}") @@ -5216,18 +4536,15 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 413 - ): + elif original_exception.status_code == 400 or original_exception.status_code == 413: exception_mapping_worked = True raise BadRequestError( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5235,7 +4552,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - request=original_exception.request, + request=original_exception.request ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5243,7 +4560,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5251,7 +4568,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response, + response=original_exception.response ) else: exception_mapping_worked = True @@ -5260,7 +4577,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "replicate": if "Incorrect authentication token" in error_str: @@ -5269,7 +4586,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response, + response=original_exception.response ) elif "input is too long" in error_str: exception_mapping_worked = True @@ -5277,7 +4594,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response, + response=original_exception.response ) elif exception_type == "ModelError": exception_mapping_worked = True @@ -5285,7 +4602,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response, + response=original_exception.response ) elif "Request was throttled" in error_str: exception_mapping_worked = True @@ -5293,7 +4610,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response, + response=original_exception.response ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -5302,19 +4619,15 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", 
llm_provider="replicate", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): + elif original_exception.status_code == 400 or original_exception.status_code == 422 or original_exception.status_code == 413: exception_mapping_worked = True raise BadRequestError( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5322,7 +4635,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - request=original_exception.request, + request=original_exception.request ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5330,7 +4643,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5338,60 +4651,48 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response, + response=original_exception.response ) exception_mapping_worked = True raise APIError( - status_code=500, + status_code=500, message=f"ReplicateException - {str(original_exception)}", llm_provider="replicate", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "Too many input tokens" in error_str - ): + if "too many tokens" in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"BedrockException: Context Window Error - {error_str}", - model=model, + model=model, llm_provider="bedrock", - response=original_exception.response, + response=original_exception.response ) if "Malformed input request" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, + message=f"BedrockException - {error_str}", + model=model, llm_provider="bedrock", - response=original_exception.response, + response=original_exception.response ) - if ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): + if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str: exception_mapping_worked = True raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, + message=f"BedrockException Invalid Authentication - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response ) - if ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): + if "throttlingException" in error_str or "ThrottlingException" in error_str: exception_mapping_worked = True raise 
RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, + message=f"BedrockException: Rate Limit Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 500: @@ -5400,7 +4701,7 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 401: exception_mapping_worked = True @@ -5408,55 +4709,49 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response, + response=original_exception.response ) - elif custom_llm_provider == "sagemaker": + elif custom_llm_provider == "sagemaker": if "Unable to locate credentials" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, + message=f"SagemakerException - {error_str}", + model=model, llm_provider="sagemaker", - response=original_exception.response, + response=original_exception.response ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): + elif "Input validation error: `best_of` must be > 0 and <= 2" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, llm_provider="sagemaker", - response=original_exception.response, + response=original_exception.response ) elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): + if "Vertex AI API has not been used in project" in error_str or "Unable to find your project" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=original_exception.response ) - elif "403" in error_str: + elif "403" in error_str: exception_mapping_worked = True - raise U( - message=f"VertexAIException - {error_str}", - model=model, + raise BadRequestError( + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=original_exception.response ) elif "The response was blocked." 
in error_str: exception_mapping_worked = True raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -5465,16 +4760,16 @@ def exception_type( message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=original_exception.response ) - if original_exception.status_code == 500: + if original_exception.status_code == 500: exception_mapping_worked = True raise APIError( message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "palm": if "503 Getting metadata" in error_str: @@ -5482,10 +4777,10 @@ def exception_type( # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. exception_mapping_worked = True raise BadRequestError( - message=f"PalmException - Invalid api key", - model=model, + message=f"PalmException - Invalid api key", + model=model, llm_provider="palm", - response=original_exception.response, + response=original_exception.response ) if "400 Request payload size exceeds" in error_str: exception_mapping_worked = True @@ -5493,7 +4788,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -5502,7 +4797,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response, + response=original_exception.response ) # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes elif custom_llm_provider == "cohere": # Cohere @@ -5515,7 +4810,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) elif "too many tokens" in error_str: exception_mapping_worked = True @@ -5523,19 +4818,16 @@ def exception_type( message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere", - response=original_exception.response, + response=original_exception.response ) elif hasattr(original_exception, "status_code"): - if ( - original_exception.status_code == 400 - or original_exception.status_code == 498 - ): + if original_exception.status_code == 400 or original_exception.status_code == 498: exception_mapping_worked = True raise BadRequestError( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5543,7 +4835,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) elif ( "CohereConnectionError" in exception_type @@ -5553,7 +4845,7 @@ def exception_type( message=f"CohereException - 
{original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) elif "invalid type:" in error_str: exception_mapping_worked = True @@ -5561,7 +4853,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) elif "Unexpected server error" in error_str: exception_mapping_worked = True @@ -5569,17 +4861,17 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response, + response=original_exception.response ) else: if hasattr(original_exception, "status_code"): exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - request=original_exception.request, + request=original_exception.request ) raise original_exception elif custom_llm_provider == "huggingface": @@ -5589,15 +4881,15 @@ def exception_type( message=error_str, model=model, llm_provider="huggingface", - response=original_exception.response, + response=original_exception.response ) elif "A valid user token is required" in error_str: exception_mapping_worked = True raise BadRequestError( - message=error_str, + message=error_str, llm_provider="huggingface", model=model, - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -5606,7 +4898,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -5614,7 +4906,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5622,7 +4914,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - request=original_exception.request, + request=original_exception.request ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5630,16 +4922,16 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response, + response=original_exception.response ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "ai21": if hasattr(original_exception, "message"): @@ -5649,15 +4941,15 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + response=original_exception.response ) - if "Bad or missing API token." 
in original_exception.message: + if "Bad or missing API token." in original_exception.message: exception_mapping_worked = True raise BadRequestError( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -5666,7 +4958,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5674,7 +4966,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - request=original_exception.request, + request=original_exception.request ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -5682,7 +4974,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5690,16 +4982,16 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response, + response=original_exception.response ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "nlp_cloud": if "detail" in error_str: @@ -5709,7 +5001,7 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=original_exception.response ) elif "value is not a valid" in error_str: exception_mapping_worked = True @@ -5717,180 +5009,140 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=original_exception.response ) - else: + else: exception_mapping_worked = True raise APIError( status_code=500, message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - request=original_exception.request, + request=original_exception.request ) - if hasattr( - original_exception, "status_code" - ): # https://docs.nlpcloud.com/?shell#errors - if ( - original_exception.status_code == 400 - or original_exception.status_code == 406 - or original_exception.status_code == 413 - or original_exception.status_code == 422 - ): + if hasattr(original_exception, "status_code"): # https://docs.nlpcloud.com/?shell#errors + if original_exception.status_code == 400 or original_exception.status_code == 406 or original_exception.status_code == 413 or original_exception.status_code == 422: exception_mapping_worked = True raise BadRequestError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - original_exception.status_code == 401 - or original_exception.status_code == 403 - ): + elif original_exception.status_code == 401 or original_exception.status_code == 
403: exception_mapping_worked = True raise AuthenticationError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - original_exception.status_code == 522 - or original_exception.status_code == 524 - ): + elif original_exception.status_code == 522 or original_exception.status_code == 524: exception_mapping_worked = True raise Timeout( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - request=original_exception.request, + request=original_exception.request ) - elif ( - original_exception.status_code == 429 - or original_exception.status_code == 402 - ): + elif original_exception.status_code == 429 or original_exception.status_code == 402: exception_mapping_worked = True raise RateLimitError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - original_exception.status_code == 500 - or original_exception.status_code == 503 - ): + elif original_exception.status_code == 500 or original_exception.status_code == 503: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request, + request=original_exception.request ) - elif ( - original_exception.status_code == 504 - or original_exception.status_code == 520 - ): + elif original_exception.status_code == 504 or original_exception.status_code == 520: exception_mapping_worked = True raise ServiceUnavailableError( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - response=original_exception.response, + response=original_exception.response ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "together_ai": import json - try: error_response = json.loads(error_str) except: error_response = {"error": error_str} - if ( - "error" in error_response - and "`inputs` tokens + `max_new_tokens` must be <=" - in error_response["error"] - ): + if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]: exception_mapping_worked = True raise ContextWindowExceededError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=original_exception.response ) - elif ( - "error" in error_response - and "invalid private key" in error_response["error"] - ): + elif "error" in error_response and "invalid private key" in error_response["error"]: exception_mapping_worked = True raise AuthenticationError( message=f"TogetherAIException - {error_response['error']}", llm_provider="together_ai", model=model, - response=original_exception.response, + response=original_exception.response ) - elif ( - "error" in error_response - and "INVALID_ARGUMENT" in error_response["error"] - ): + elif "error" in error_response and "INVALID_ARGUMENT" in 
error_response["error"]: exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=original_exception.response ) - - elif ( - "error" in error_response - and "API key doesn't match expected format." - in error_response["error"] - ): + + elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]: exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=original_exception.response ) - elif ( - "error_type" in error_response - and error_response["error_type"] == "validation" - ): + elif "error_type" in error_response and error_response["error_type"] == "validation": exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response, + response=original_exception.response ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"TogetherAIException - {original_exception.message}", - model=model, - llm_provider="together_ai", - request=original_exception.request, - ) + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + model=model, + llm_provider="together_ai", + request=original_exception.request + ) elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"TogetherAIException - {original_exception.message}", - llm_provider="together_ai", - model=model, - response=original_exception.response, - ) + exception_mapping_worked = True + raise RateLimitError( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + response=original_exception.response + ) elif original_exception.status_code == 524: exception_mapping_worked = True raise Timeout( @@ -5898,34 +5150,31 @@ def exception_type( llm_provider="together_ai", model=model, ) - else: + else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"TogetherAIException - {original_exception.message}", llm_provider="together_ai", model=model, - request=original_exception.request, + request=original_exception.request ) elif custom_llm_provider == "aleph_alpha": - if ( - "This is longer than the model's maximum context length" - in error_str - ): + if "This is longer than the model's maximum context length" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=original_exception.response ) elif "InvalidToken" in error_str or "No token provided" in error_str: exception_mapping_worked = True raise BadRequestError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=original_exception.response ) elif hasattr(original_exception, "status_code"): print_verbose(f"status code: 
{original_exception.status_code}") @@ -5934,7 +5183,7 @@ def exception_type( raise AuthenticationError( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", - model=model, + model=model ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -5942,7 +5191,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5950,7 +5199,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5958,35 +5207,30 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response, + response=original_exception.response ) raise original_exception raise original_exception elif custom_llm_provider == "ollama": - if "no attribute 'async_get_ollama_response_stream" in error_str: - exception_mapping_worked = True - raise ImportError( - "Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'" - ) if isinstance(original_exception, dict): error_str = original_exception.get("error", "") - else: + else: error_str = str(original_exception) if "no such file or directory" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", - model=model, - llm_provider="ollama", - response=original_exception.response, - ) - elif "Failed to establish a new connection" in error_str: + message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", + model=model, + llm_provider="ollama", + response=original_exception.response + ) + elif "Failed to establish a new connection" in error_str: exception_mapping_worked = True raise ServiceUnavailableError( message=f"OllamaException: {original_exception}", - llm_provider="ollama", + llm_provider="ollama", model=model, - response=original_exception.response, + response=original_exception.response ) elif "Invalid response object from API" in error_str: exception_mapping_worked = True @@ -5994,7 +5238,7 @@ def exception_type( message=f"OllamaException: {original_exception}", llm_provider="ollama", model=model, - response=original_exception.response, + response=original_exception.response ) elif custom_llm_provider == "vllm": if hasattr(original_exception, "status_code"): @@ -6004,16 +5248,16 @@ def exception_type( message=f"VLLMException - {original_exception.message}", llm_provider="vllm", model=model, - request=original_exception.request, + request=original_exception.request ) - elif custom_llm_provider == "azure": + elif custom_llm_provider == "azure": if "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response, + response=original_exception.response ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True @@ -6021,7 +5265,7 @@ def exception_type( message=f"AzureException - 
{original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response, + response=original_exception.response ) elif "invalid_request_error" in error_str: exception_mapping_worked = True @@ -6029,7 +5273,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response, + response=original_exception.response ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -6039,7 +5283,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -6047,7 +5291,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - request=original_exception.request, + request=original_exception.request ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -6055,7 +5299,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response, + response=original_exception.response ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -6063,16 +5307,16 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response, + response=original_exception.response ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - request=original_exception.request, + request=original_exception.request ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -6080,36 +5324,31 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider="azure", model=model, - request=original_exception.request, + request=original_exception.request ) - if ( - "BadRequestError.__init__() missing 1 required positional argument: 'param'" - in str(original_exception) - ): # deal with edge-case invalid request error bug in openai-python sdk + if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}", - model=model, + model=model, llm_provider=custom_llm_provider, - response=original_exception.response, + response=original_exception.response ) - else: # ensure generic errors always return APIConnectionError= + else: # ensure generic errors always return APIConnectionError= exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request=original_exception.request, + request=original_exception.request ) - else: - raise APIConnectionError( + else: + raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), # stub the 
request + request= httpx.Request(method="POST", url="https://api.openai.com/v1/") # stub the request ) except Exception as e: # LOGGING @@ -6143,10 +5382,9 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None): executor.submit(litellm_telemetry, data) # threading.Thread(target=litellm_telemetry, args=(data,), daemon=True).start() - def get_or_generate_uuid(): temp_dir = os.path.join(os.path.abspath(os.sep), "tmp") - uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") + uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") try: # Try to open the file and load the UUID with open(uuid_file, "r") as file: @@ -6158,19 +5396,19 @@ def get_or_generate_uuid(): except FileNotFoundError: # Generate a new UUID if the file doesn't exist or is empty - try: + try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open(uuid_file, "w") as file: file.write(uuid_value) - except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt + except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open("litellm_uuid.txt", "w") as file: file.write(uuid_value) - except: # if this 3rd attempt fails just pass - # Good first issue for someone to improve this function :) + except: # if this 3rd attempt fails just pass + # Good first issue for someone to improve this function :) return except: # [Non-Blocking Error] @@ -6187,13 +5425,17 @@ def litellm_telemetry(data): uuid_value = str(uuid.uuid4()) try: # Prepare the data to send to litellm logging api - try: + try: pkg_version = importlib.metadata.version("litellm") except: pkg_version = None if "model" not in data: data["model"] = None - payload = {"uuid": uuid_value, "data": data, "version:": pkg_version} + payload = { + "uuid": uuid_value, + "data": data, + "version:": pkg_version + } # Make the POST request to litellm logging api response = requests.post( "https://litellm-logging.onrender.com/logging", @@ -6205,33 +5447,29 @@ def litellm_telemetry(data): # [Non-Blocking Error] return - ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str, default_value: Optional[str] = None): - if secret_name.startswith("os.environ/"): +def get_secret(secret_name: str, default_value: Optional[str]=None): + if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") - try: + try: if litellm.secret_manager_client is not None: try: client = litellm.secret_manager_client - if ( - type(client).__module__ + "." + type(client).__name__ - == "azure.keyvault.secrets._client.SecretClient" - ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + if type(client).__module__ + '.' 
+ type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient secret = retrieved_secret = client.get_secret(secret_name).value - else: # assume the default is infisicial client + else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ + except: # check if it's in os.environ secret = os.environ.get(secret_name) return secret else: return os.environ.get(secret_name) - except Exception as e: - if default_value is not None: + except Exception as e: + if default_value is not None: return default_value - else: + else: raise e @@ -6239,9 +5477,7 @@ def get_secret(secret_name: str, default_value: Optional[str] = None): # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere class CustomStreamWrapper: - def __init__( - self, completion_stream, model, custom_llm_provider=None, logging_obj=None - ): + def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None): self.model = model self.custom_llm_provider = custom_llm_provider self.logging_obj = logging_obj @@ -6249,7 +5485,7 @@ class CustomStreamWrapper: self.sent_first_chunk = False self.sent_last_chunk = False self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "", ""] - self.holding_chunk = "" + self.holding_chunk = "" self.complete_response = "" def __iter__(self): @@ -6258,111 +5494,94 @@ class CustomStreamWrapper: def __aiter__(self): return self - def process_chunk(self, chunk: str): + def process_chunk(self, chunk: str): """ NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta. """ - try: + try: chunk = chunk.strip() self.complete_response = self.complete_response.strip() - if chunk.startswith(self.complete_response): + if chunk.startswith(self.complete_response): # Remove last_sent_chunk only if it appears at the start of the new chunk - chunk = chunk[len(self.complete_response) :] + chunk = chunk[len(self.complete_response):] self.complete_response += chunk - return chunk - except Exception as e: + return chunk + except Exception as e: raise e - - def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): + + def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): hold = False - if finish_reason: - for token in self.special_tokens: + if finish_reason: + for token in self.special_tokens: if token in chunk: - chunk = chunk.replace(token, "") + chunk = chunk.replace(token, "") return hold, chunk - + if self.sent_first_chunk is True: return hold, chunk curr_chunk = self.holding_chunk + chunk curr_chunk = curr_chunk.strip() - for token in self.special_tokens: - if len(curr_chunk) < len(token) and curr_chunk in token: + for token in self.special_tokens: + if len(curr_chunk) < len(token) and curr_chunk in token: hold = True elif len(curr_chunk) >= len(token): if token in curr_chunk: self.holding_chunk = curr_chunk.replace(token, "") hold = True - else: + else: pass - - if hold is False: # reset - self.holding_chunk = "" + + if hold is False: # reset + self.holding_chunk = "" return hold, curr_chunk + def handle_anthropic_chunk(self, chunk): str_line = chunk.decode("utf-8") # Convert bytes to string - text = "" + text = "" is_finished = False finish_reason = None if str_line.startswith("data:"): data_json = json.loads(str_line[5:]) - text = data_json.get("completion", "") - if data_json.get("stop_reason", 
None): + text = data_json.get("completion", "") + if data_json.get("stop_reason", None): is_finished = True finish_reason = data_json["stop_reason"] - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif "error" in str_line: raise ValueError(f"Unable to parse response. Original response: {str_line}") else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} def handle_together_ai_chunk(self, chunk): chunk = chunk.decode("utf-8") - text = "" + text = "" is_finished = False finish_reason = None - if "text" in chunk: + if "text" in chunk: text_index = chunk.find('"text":"') # this checks if text: exists text_start = text_index + len('"text":"') text_end = chunk.find('"}', text_start) if text_index != -1 and text_end != -1: extracted_text = chunk[text_start:text_end] text = extracted_text - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif "[DONE]" in chunk: return {"text": text, "is_finished": True, "finish_reason": "stop"} elif "error" in chunk: raise ValueError(chunk) else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} def handle_huggingface_chunk(self, chunk): try: if type(chunk) != str: - chunk = chunk.decode( - "utf-8" - ) # DO NOT REMOVE this: This is required for HF inference API + Streaming - text = "" + chunk = chunk.decode("utf-8") # DO NOT REMOVE this: This is required for HF inference API + Streaming + text = "" is_finished = False finish_reason = "" print_verbose(f"chunk: {chunk}") @@ -6371,72 +5590,52 @@ class CustomStreamWrapper: print_verbose(f"data json: {data_json}") if "token" in data_json and "text" in data_json["token"]: text = data_json["token"]["text"] - if data_json.get("details", False) and data_json["details"].get( - "finish_reason", False - ): + if data_json.get("details", False) and data_json["details"].get("finish_reason", False): is_finished = True finish_reason = data_json["details"]["finish_reason"] - elif data_json.get( - "generated_text", False - ): # if full generated text exists, then stream is complete - text = "" # don't return the final bos token + elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete + text = "" # don't return the final bos token is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - elif "error" in chunk: + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + elif "error" in chunk: raise ValueError(chunk) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - except Exception as e: + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + except Exception as e: traceback.print_exc() # raise(e) - - def handle_ai21_chunk(self, chunk): # fake streaming + + def handle_ai21_chunk(self, chunk): # fake streaming chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["completions"][0]["data"]["text"] is_finished = True finish_reason = "stop" - return { - "text": text, - 
"is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_maritalk_chunk(self, chunk): # fake streaming + + def handle_maritalk_chunk(self, chunk): # fake streaming chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["answer"] is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_nlp_cloud_chunk(self, chunk): - text = "" + text = "" is_finished = False finish_reason = "" try: if "dolphin" in self.model: chunk = self.process_chunk(chunk=chunk) - else: + else: data_json = json.loads(chunk) chunk = data_json["generated_text"] text = chunk @@ -6444,14 +5643,10 @@ class CustomStreamWrapper: text = text.replace("[DONE]", "") is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except Exception as e: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_aleph_alpha_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) @@ -6459,36 +5654,28 @@ class CustomStreamWrapper: text = data_json["completions"][0]["completion"] is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_cohere_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: - text = "" + text = "" is_finished = False finish_reason = "" - if "text" in data_json: + if "text" in data_json: text = data_json["text"] - elif "is_finished" in data_json: + elif "is_finished" in data_json: is_finished = data_json["is_finished"] finish_reason = data_json["finish_reason"] - else: + else: raise Exception(data_json) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: raise ValueError(f"Unable to parse response. 
Original response: {chunk}") - + def handle_azure_chunk(self, chunk): is_finished = False finish_reason = "" @@ -6498,92 +5685,72 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif chunk.startswith("data:"): - data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): + data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): try: - if len(data_json["choices"]) > 0: - text = data_json["choices"][0]["delta"].get("content", "") - if data_json["choices"][0].get("finish_reason", None): + if len(data_json["choices"]) > 0: + text = data_json["choices"][0]["delta"].get("content", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose( - f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" - ) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: - raise ValueError( - f"Unable to parse response. Original response: {chunk}" - ) + raise ValueError(f"Unable to parse response. Original response: {chunk}") elif "error" in chunk: raise ValueError(f"Unable to parse response. Original response: {chunk}") else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} def handle_replicate_chunk(self, chunk): try: - text = "" + text = "" is_finished = False finish_reason = "" - if "output" in chunk: - text = chunk["output"] - if "status" in chunk: + if "output" in chunk: + text = chunk['output'] + if "status" in chunk: if chunk["status"] == "succeeded": is_finished = True finish_reason = "stop" - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_openai_chat_completion_chunk(self, chunk): - try: + + def handle_openai_chat_completion_chunk(self, chunk): + try: print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n") str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None - original_chunk = None # this is used for function/tool calling - if len(str_line.choices) > 0: + original_chunk = None # this is used for function/tool calling + if len(str_line.choices) > 0: if str_line.choices[0].delta.content is not None: text = str_line.choices[0].delta.content - else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai + else: # function/tool calling chunk - when content is None. 
in this case we just return the original chunk from openai original_chunk = str_line if str_line.choices[0].finish_reason: is_finished = True finish_reason = str_line.choices[0].finish_reason return { - "text": text, - "is_finished": is_finished, + "text": text, + "is_finished": is_finished, "finish_reason": finish_reason, - "original_chunk": str_line, + "original_chunk": str_line } except Exception as e: traceback.print_exc() raise e def handle_openai_text_completion_chunk(self, chunk): - try: + try: str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None print_verbose(f"str_line: {str_line}") @@ -6591,36 +5758,20 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif str_line.startswith("data:"): data_json = json.loads(str_line[5:]) print_verbose(f"delta content: {data_json}") - text = data_json["choices"][0].get("text", "") - if data_json["choices"][0].get("finish_reason", None): + text = data_json["choices"][0].get("text", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose( - f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" - ) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif "error" in str_line: - raise ValueError( - f"Unable to parse response. Original response: {str_line}" - ) + raise ValueError(f"Unable to parse response. Original response: {str_line}") else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} except Exception as e: traceback.print_exc() @@ -6638,22 +5789,14 @@ class CustomStreamWrapper: return "" data_json = json.loads(chunk) if "model_output" in data_json: - if ( - isinstance(data_json["model_output"], dict) - and "data" in data_json["model_output"] - and isinstance(data_json["model_output"]["data"], list) - ): + if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list): return data_json["model_output"]["data"][0] elif isinstance(data_json["model_output"], str): return data_json["model_output"] - elif "completion" in data_json and isinstance( - data_json["completion"], str - ): + elif "completion" in data_json and isinstance(data_json["completion"], str): return data_json["completion"] else: - raise ValueError( - f"Unable to parse response. Original response: {chunk}" - ) + raise ValueError(f"Unable to parse response. 
Original response: {chunk}") else: return "" else: @@ -6662,57 +5805,50 @@ class CustomStreamWrapper: traceback.print_exc() return "" - def handle_ollama_stream(self, chunk): - try: + def handle_ollama_stream(self, chunk): + try: json_chunk = json.loads(chunk) - if "error" in json_chunk: + if "error" in json_chunk: raise Exception(f"Ollama Error - {json_chunk}") - - text = "" + + text = "" is_finished = False finish_reason = None if json_chunk["done"] == True: text = "" is_finished = True finish_reason = "stop" - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} elif json_chunk["response"]: print_verbose(f"delta content: {json_chunk}") text = json_chunk["response"] - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - else: + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + else: raise Exception(f"Ollama Error - {json_chunk}") - except Exception as e: + except Exception as e: raise e + def handle_bedrock_stream(self, chunk): if hasattr(chunk, "get"): - chunk = chunk.get("chunk") - chunk_data = json.loads(chunk.get("bytes").decode()) + chunk = chunk.get('chunk') + chunk_data = json.loads(chunk.get('bytes').decode()) else: chunk_data = json.loads(chunk.decode()) if chunk_data: - text = "" + text = "" is_finished = False finish_reason = "" - if "outputText" in chunk_data: - text = chunk_data["outputText"] + if "outputText" in chunk_data: + text = chunk_data['outputText'] # ai21 mapping - if "ai21" in self.model: # fake ai21 streaming - text = chunk_data.get("completions")[0].get("data").get("text") + if "ai21" in self.model: # fake ai21 streaming + text = chunk_data.get('completions')[0].get('data').get('text') is_finished = True finish_reason = "stop" # anthropic mapping - elif "completion" in chunk_data: - text = chunk_data["completion"] # bedrock.anthropic + elif "completion" in chunk_data: + text = chunk_data['completion'] # bedrock.anthropic stop_reason = chunk_data.get("stop_reason", None) if stop_reason != None: is_finished = True @@ -6720,26 +5856,22 @@ class CustomStreamWrapper: ######## bedrock.cohere mappings ############### # meta mapping elif "generation" in chunk_data: - text = chunk_data["generation"] # bedrock.meta + text = chunk_data['generation'] # bedrock.meta # cohere mapping elif "text" in chunk_data: - text = chunk_data["text"] # bedrock.cohere + text = chunk_data["text"] # bedrock.cohere # cohere mapping for finish reason elif "finish_reason" in chunk_data: finish_reason = chunk_data["finish_reason"] is_finished = True - elif chunk_data.get("completionReason", None): + elif chunk_data.get("completionReason", None): is_finished = True finish_reason = chunk_data["completionReason"] - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} return "" - + def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) model_response.choices = [StreamingChoices()] @@ -6751,83 +5883,62 @@ class CustomStreamWrapper: if self.custom_llm_provider and self.custom_llm_provider == "anthropic": response_obj = self.handle_anthropic_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - 
model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] - elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] + elif ( + self.custom_llm_provider and self.custom_llm_provider == "together_ai"): response_obj = self.handle_together_ai_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": response_obj = self.handle_huggingface_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] - elif ( - self.custom_llm_provider and self.custom_llm_provider == "baseten" - ): # baseten doesn't provide streaming + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] + elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming completion_obj["content"] = self.handle_baseten_chunk(chunk) - elif ( - self.custom_llm_provider and self.custom_llm_provider == "ai21" - ): # ai21 doesn't provide streaming + elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming response_obj = self.handle_ai21_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": response_obj = self.handle_maritalk_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider and self.custom_llm_provider == "vllm": completion_obj["content"] = chunk[0].outputs[0].text - elif ( - self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha" - ): # aleph alpha doesn't provide streaming + elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming response_obj = self.handle_aleph_alpha_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "nlp_cloud": - try: + try: response_obj = self.handle_nlp_cloud_chunk(chunk) completion_obj["content"] = response_obj["text"] - if 
response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] except Exception as e: if self.sent_last_chunk: raise e else: - if self.sent_first_chunk is False: + if self.sent_first_chunk is False: raise Exception("An unknown error occurred with the stream") model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai": try: # print(chunk) - if hasattr(chunk, "text"): - # vertexAI chunks return + if hasattr(chunk, 'text'): + # vertexAI chunks return # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python # This Python code says "Hi" 100 times. # Create]) @@ -6835,32 +5946,28 @@ class CustomStreamWrapper: else: completion_obj["content"] = str(chunk) except StopIteration as e: - if self.sent_last_chunk: - raise e + if self.sent_last_chunk: + raise e else: model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider == "cohere": response_obj = self.handle_cohere_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "bedrock": - if self.sent_last_chunk: + if self.sent_last_chunk: raise StopIteration response_obj = self.handle_bedrock_stream(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": print_verbose(f"ENTERS SAGEMAKER STREAMING") - if len(self.completion_stream) == 0: - if self.sent_last_chunk: + if len(self.completion_stream)==0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -6868,12 +5975,10 @@ class CustomStreamWrapper: new_chunk = self.completion_stream print_verbose(f"sagemaker chunk: {new_chunk}") completion_obj["content"] = new_chunk - self.completion_stream = self.completion_stream[ - len(self.completion_stream) : - ] + self.completion_stream = self.completion_stream[len(self.completion_stream):] elif self.custom_llm_provider == "petals": - if len(self.completion_stream) == 0: - if self.sent_last_chunk: + if len(self.completion_stream)==0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -6886,8 +5991,8 @@ class CustomStreamWrapper: elif self.custom_llm_provider == "palm": # fake streaming response_obj = {} - if len(self.completion_stream) == 0: - if self.sent_last_chunk: + if len(self.completion_stream)==0: + if self.sent_last_chunk: raise 
StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -6901,50 +6006,33 @@ class CustomStreamWrapper: response_obj = self.handle_ollama_stream(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": response_obj = self.handle_openai_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] - else: # openai chat model + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] + else: # openai chat model response_obj = self.handle_openai_chat_completion_chunk(chunk) if response_obj == None: return completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj[ - "finish_reason" - ] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj["finish_reason"] model_response.model = self.model - print_verbose( - f"model_response: {model_response}; completion_obj: {completion_obj}" - ) - print_verbose( - f"model_response finish reason 3: {model_response.choices[0].finish_reason}" - ) - if ( - len(completion_obj["content"]) > 0 - ): # cannot set content of an OpenAI Object to be an empty string - hold, model_response_str = self.check_special_tokens( - chunk=completion_obj["content"], - finish_reason=model_response.choices[0].finish_reason, - ) # filter out bos/eos tokens from openai-compatible hf endpoints - print_verbose( - f"hold - {hold}, model_response_str - {model_response_str}" - ) - if hold is False: - ## check if openai/azure chunk + print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}") + print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}") + if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string + hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason) # filter out bos/eos tokens from openai-compatible hf endpoints + print_verbose(f"hold - {hold}, model_response_str - {model_response_str}") + if hold is False: + ## check if openai/azure chunk original_chunk = response_obj.get("original_chunk", None) - if original_chunk: + if original_chunk: model_response.id = original_chunk.id if len(original_chunk.choices) > 0: try: @@ -6952,99 +6040,79 @@ class CustomStreamWrapper: model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: - return - model_response.system_fingerprint = ( - original_chunk.system_fingerprint - ) + else: + return + model_response.system_fingerprint = original_chunk.system_fingerprint if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True - else: - ## else - completion_obj["content"] = model_response_str + else: + ## else + completion_obj["content"] = 
model_response_str if self.sent_first_chunk == False: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) print_verbose(f"model_response: {model_response}") return model_response - else: - return + else: + return elif model_response.choices[0].finish_reason: - # flush any remaining holding chunk + # flush any remaining holding chunk if len(self.holding_chunk) > 0: if model_response.choices[0].delta.content is None: model_response.choices[0].delta.content = self.holding_chunk else: - model_response.choices[0].delta.content = ( - self.holding_chunk + model_response.choices[0].delta.content - ) - self.holding_chunk = "" - model_response.choices[0].finish_reason = map_finish_reason( - model_response.choices[0].finish_reason - ) # ensure consistent output to openai + model_response.choices[0].delta.content = self.holding_chunk + model_response.choices[0].delta.content + self.holding_chunk = "" + model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai return model_response - elif ( - response_obj is not None - and response_obj.get("original_chunk", None) is not None - ): # function / tool calling branch - only set for openai/azure compatible endpoints + elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints # enter this branch when no content has been passed in response original_chunk = response_obj.get("original_chunk", None) model_response.id = original_chunk.id if len(original_chunk.choices) > 0: - if ( - original_chunk.choices[0].delta.function_call is not None - or original_chunk.choices[0].delta.tool_calls is not None - ): + if original_chunk.choices[0].delta.function_call is not None or original_chunk.choices[0].delta.tool_calls is not None: try: delta = dict(original_chunk.choices[0].delta) model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: + else: return - else: + else: return model_response.system_fingerprint = original_chunk.system_fingerprint if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True return model_response - else: + else: return except StopIteration: raise StopIteration - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() e.message = str(e) - raise exception_type( - model=self.model, - custom_llm_provider=self.custom_llm_provider, - original_exception=e, - ) + raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e) ## needs to handle the empty string case (even starting chunk can be an empty string) def __next__(self): try: while True: - if isinstance(self.completion_stream, str) or isinstance( - self.completion_stream, bytes - ): + if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes): chunk = self.completion_stream else: chunk = next(self.completion_stream) print_verbose(f"value of chunk: {chunk} ") - if chunk is not None and chunk != b"": + if chunk is not None and chunk != b'': print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") response = self.chunk_creator(chunk=chunk) print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}") - if response is None: + if response is None: continue ## LOGGING - threading.Thread( - 
target=self.logging_obj.success_handler, args=(response,) - ).start() # log response + threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response return response except StopIteration: raise # Re-raise StopIteration @@ -7052,59 +6120,43 @@ class CustomStreamWrapper: print_verbose(f"HITS AN ERROR: {str(e)}\n\n {traceback.format_exc()}") traceback_exception = traceback.format_exc() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated - threading.Thread( - target=self.logging_obj.failure_handler, args=(e, traceback_exception) - ).start() + threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start() raise e + + async def __anext__(self): try: - if ( - self.custom_llm_provider == "openai" + if (self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure" or self.custom_llm_provider == "custom_openai" or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "huggingface" or self.custom_llm_provider == "ollama" - or self.custom_llm_provider == "vertex_ai" - ): + or self.custom_llm_provider == "vertex_ai"): print_verbose(f"INSIDE ASYNC STREAMING!!!") - print_verbose( - f"value of async completion stream: {self.completion_stream}" - ) + print_verbose(f"value of async completion stream: {self.completion_stream}") async for chunk in self.completion_stream: print_verbose(f"value of async chunk: {chunk}") if chunk == "None" or chunk is None: raise Exception - # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. + # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. 
# __anext__ also calls async_success_handler, which does logging print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") - processed_chunk = self.chunk_creator(chunk=chunk) - print_verbose( - f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}" - ) - if processed_chunk is None: + processed_chunk = self.chunk_creator(chunk=chunk) + print_verbose(f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}") + if processed_chunk is None: continue ## LOGGING - threading.Thread( - target=self.logging_obj.success_handler, args=(processed_chunk,) - ).start() # log response - asyncio.create_task( - self.logging_obj.async_success_handler( - processed_chunk, - ) - ) + threading.Thread(target=self.logging_obj.success_handler, args=(processed_chunk,)).start() # log response + asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) return processed_chunk raise StopAsyncIteration - else: # temporary patch for non-aiohttp async calls + else: # temporary patch for non-aiohttp async calls # example - boto3 bedrock llms processed_chunk = next(self) - asyncio.create_task( - self.logging_obj.async_success_handler( - processed_chunk, - ) - ) + asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) return processed_chunk except StopAsyncIteration: raise @@ -7113,12 +6165,9 @@ class CustomStreamWrapper: except Exception as e: traceback_exception = traceback.format_exc() # Handle any exceptions that might occur during streaming - asyncio.create_task( - self.logging_obj.async_failure_handler(e, traceback_exception) - ) + asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception)) raise StopAsyncIteration - class TextCompletionStreamWrapper: def __init__(self, completion_stream, model): self.completion_stream = completion_stream @@ -7129,18 +6178,16 @@ class TextCompletionStreamWrapper: def __aiter__(self): return self - + def convert_to_text_completion_object(self, chunk: ModelResponse): - try: + try: response = TextCompletionResponse() response["id"] = chunk.get("id", None) response["object"] = "text_completion" response["created"] = response.get("created", None) response["model"] = response.get("model", None) text_choices = TextChoices() - if isinstance( - chunk, Choices - ): # chunk should always be of type StreamingChoices + if isinstance(chunk, Choices): # chunk should always be of type StreamingChoices raise Exception text_choices["text"] = chunk["choices"][0]["delta"]["content"] text_choices["index"] = response["choices"][0]["index"] @@ -7148,9 +6195,7 @@ class TextCompletionStreamWrapper: response["choices"] = [text_choices] return response except Exception as e: - raise Exception( - f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}" - ) + raise Exception(f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}") def __next__(self): # model_response = ModelResponse(stream=True, model=self.model) @@ -7158,34 +6203,32 @@ class TextCompletionStreamWrapper: try: for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopIteration - except Exception as e: - print(f"got exception {e}") # noqa + except Exception as e: + print(f"got exception {e}") # noqa async def __anext__(self): try: 
async for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopAsyncIteration - def mock_completion_streaming_obj(model_response, mock_response, model): for i in range(0, len(mock_response), 3): - completion_obj = {"role": "assistant", "content": mock_response[i : i + 3]} + completion_obj = {"role": "assistant", "content": mock_response[i: i+3]} model_response.choices[0].delta = completion_obj yield model_response - ########## Reading Config File ############################ def read_config_args(config_path) -> dict: try: @@ -7200,25 +6243,23 @@ def read_config_args(config_path) -> dict: except Exception as e: raise e - ########## experimental completion variants ############################ - def completion_with_config(config: Union[dict, str], **kwargs): """ - Generate a litellm.completion() using a config dict and all supported completion args + Generate a litellm.completion() using a config dict and all supported completion args Example config; config = { "default_fallback_models": # [Optional] List of model names to try if a call fails - "available_models": # [Optional] List of all possible models you could call + "available_models": # [Optional] List of all possible models you could call "adapt_to_prompt_size": # [Optional] True/False - if you want to select model based on prompt size (will pick from available_models) "model": { "model-name": { - "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. + "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. 
"error_handling": { "error-type": { # One of the errors listed here - https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list - "fallback_model": "" # str, name of the model it should try instead, when that error occurs + "fallback_model": "" # str, name of the model it should try instead, when that error occurs } } } @@ -7242,11 +6283,11 @@ def completion_with_config(config: Union[dict, str], **kwargs): raise Exception("Config path must be a string or a dictionary.") else: raise Exception("Config path not passed in.") - + if config is None: raise Exception("No completion config in the config file") - - models_with_config = config["model"].keys() + + models_with_config = config["model"].keys() model = kwargs["model"] messages = kwargs["messages"] @@ -7257,16 +6298,13 @@ def completion_with_config(config: Union[dict, str], **kwargs): trim_messages_flag = config.get("trim_messages", False) prompt_larger_than_model = False max_model = model - try: + try: max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: - max_tokens = 2048 # assume curr model's max window is 2048 tokens + max_tokens = 2048 # assume curr model's max window is 2048 tokens if adapt_to_prompt_size: - ## Pick model based on token window - prompt_tokens = litellm.token_counter( - model="gpt-3.5-turbo", - text="".join(message["content"] for message in messages), - ) + ## Pick model based on token window + prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages)) try: curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: @@ -7275,9 +6313,7 @@ def completion_with_config(config: Union[dict, str], **kwargs): prompt_larger_than_model = True for available_model in available_models: try: - curr_max_tokens = litellm.get_max_tokens(available_model)[ - "max_tokens" - ] + curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"] if curr_max_tokens > max_tokens: max_tokens = curr_max_tokens max_model = available_model @@ -7291,16 +6327,16 @@ def completion_with_config(config: Union[dict, str], **kwargs): kwargs["messages"] = messages kwargs["model"] = model - try: - if model in models_with_config: + try: + if model in models_with_config: ## Moderation check if config["model"][model].get("needs_moderation"): input = " ".join(message["content"] for message in messages) response = litellm.moderation(input=input) flagged = response["results"][0]["flagged"] - if flagged: + if flagged: raise Exception("This response was flagged as inappropriate") - + ## Model-specific Error Handling error_handling = None if config["model"][model].get("error_handling"): @@ -7312,25 +6348,22 @@ def completion_with_config(config: Union[dict, str], **kwargs): except Exception as e: exception_name = type(e).__name__ fallback_model = None - if error_handling and exception_name in error_handling: + if error_handling and exception_name in error_handling: error_handler = error_handling[exception_name] - # either switch model or api key + # either switch model or api key fallback_model = error_handler.get("fallback_model", None) - if fallback_model: + if fallback_model: kwargs["model"] = fallback_model return litellm.completion(**kwargs) raise e - else: + else: return litellm.completion(**kwargs) except Exception as e: if fallback_models: model = fallback_models.pop(0) - return completion_with_fallbacks( - model=model, messages=messages, fallbacks=fallback_models - ) + return completion_with_fallbacks(model=model, messages=messages, fallbacks=fallback_models) 
raise e - def completion_with_fallbacks(**kwargs): nested_kwargs = kwargs.pop("kwargs", {}) response = None @@ -7348,10 +6381,8 @@ def completion_with_fallbacks(**kwargs): for model in fallbacks: # loop thru all models try: - # check if it's dict or new model string - if isinstance( - model, dict - ): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) + # check if it's dict or new model string + if isinstance(model, dict): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) kwargs["api_key"] = model.get("api_key", None) kwargs["api_base"] = model.get("api_base", None) model = model.get("model", original_model) @@ -7374,10 +6405,7 @@ def completion_with_fallbacks(**kwargs): print_verbose(f"trying to make completion call with model: {model}") kwargs["litellm_call_id"] = litellm_call_id - kwargs = { - **kwargs, - **nested_kwargs, - } # combine the openai + litellm params at the same level + kwargs = {**kwargs, **nested_kwargs} # combine the openai + litellm params at the same level response = litellm.completion(**kwargs, model=model) print_verbose(f"response: {response}") if response != None: @@ -7392,24 +6420,18 @@ def completion_with_fallbacks(**kwargs): pass return response - def process_system_message(system_message, max_tokens, model): system_message_event = {"role": "system", "content": system_message} system_message_tokens = get_token_count([system_message_event], model) if system_message_tokens > max_tokens: - print_verbose( - "`tokentrimmer`: Warning, system message exceeds token limit. Trimming..." - ) + print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...") # shorten system message to fit within max_tokens - new_system_message = shorten_message_to_fit_limit( - system_message_event, max_tokens, model - ) + new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model) system_message_tokens = get_token_count([new_system_message], model) - + return system_message_event, max_tokens - system_message_tokens - def process_messages(messages, max_tokens, model): # Process messages from older to more recent messages = messages[::-1] @@ -7420,26 +6442,17 @@ def process_messages(messages, max_tokens, model): available_tokens = max_tokens - used_tokens if available_tokens <= 3: break - final_messages = attempt_message_addition( - final_messages=final_messages, - message=message, - available_tokens=available_tokens, - max_tokens=max_tokens, - model=model, - ) + final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model) return final_messages - -def attempt_message_addition( - final_messages, message, available_tokens, max_tokens, model -): +def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model): temp_messages = [message] + final_messages temp_message_tokens = get_token_count(messages=temp_messages, model=model) if temp_message_tokens <= max_tokens: return temp_messages - + # if temp_message_tokens > max_tokens, try shortening temp_messages elif "function_call" not in message: # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens) @@ -7449,18 +6462,19 @@ def attempt_message_addition( return final_messages - def can_add_message(message, messages, max_tokens, model): if get_token_count(messages + [message], model) <= 
max_tokens: return True return False - def get_token_count(messages, model): return token_counter(model=model, messages=messages) -def shorten_message_to_fit_limit(message, tokens_needed, model): +def shorten_message_to_fit_limit( + message, + tokens_needed, + model): """ Shorten a message to fit within a token limit by removing characters from the middle. """ @@ -7468,7 +6482,7 @@ def shorten_message_to_fit_limit(message, tokens_needed, model): # For OpenAI models, even blank messages cost 7 token, # and if the buffer is less than 3, the while loop will never end, # hence the value 10. - if "gpt" in model and tokens_needed <= 10: + if 'gpt' in model and tokens_needed <= 10: return message content = message["content"] @@ -7480,22 +6494,21 @@ def shorten_message_to_fit_limit(message, tokens_needed, model): break ratio = (tokens_needed) / total_tokens - - new_length = int(len(content) * ratio) - 1 + + new_length = int(len(content) * ratio) -1 new_length = max(0, new_length) half_length = new_length // 2 left_half = content[:half_length] right_half = content[-half_length:] - trimmed_content = left_half + ".." + right_half + trimmed_content = left_half + '..' + right_half message["content"] = trimmed_content content = trimmed_content return message - -# LiteLLM token trimmer +# LiteLLM token trimmer # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # Credits for this code go to Killian Lucas def trim_messages( @@ -7503,8 +6516,8 @@ def trim_messages( model: Optional[str] = None, trim_ratio: float = 0.75, return_response_tokens: bool = False, - max_tokens=None, -): + max_tokens = None + ): """ Trim a list of messages to fit within a model's token limit. @@ -7526,18 +6539,18 @@ def trim_messages( if max_tokens == None: # Check if model is valid if model in litellm.model_cost: - max_tokens_for_model = litellm.model_cost[model]["max_tokens"] + max_tokens_for_model = litellm.model_cost[model]['max_tokens'] max_tokens = int(max_tokens_for_model * trim_ratio) else: - # if user did not specify max tokens + # if user did not specify max tokens # or passed an llm litellm does not know # do nothing, just return messages - return - - system_message = "" + return + + system_message = "" for message in messages: if message["role"] == "system": - system_message += "\n" if system_message else "" + system_message += '\n' if system_message else '' system_message += message["content"] current_tokens = token_counter(model=model, messages=messages) @@ -7545,47 +6558,38 @@ def trim_messages( # Do nothing if current tokens under messages if current_tokens < max_tokens: - return messages - + return messages + #### Trimming messages if current_tokens > max_tokens - print_verbose( - f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}" - ) + print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}") if system_message: - system_message_event, max_tokens = process_system_message( - system_message=system_message, max_tokens=max_tokens, model=model - ) + system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) - if max_tokens == 0: # the system messages are too long + if max_tokens == 0: # the system messages are too long return [system_message_event] - - # Since all system messages are combined and trimmed to fit the max_tokens, + + # Since all system messages are combined and trimmed to fit the 
max_tokens, # we remove all system messages from the messages list messages = [message for message in messages if message["role"] != "system"] - final_messages = process_messages( - messages=messages, max_tokens=max_tokens, model=model - ) + final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model) # Add system message to the beginning of the final messages if system_message: final_messages = [system_message_event] + final_messages - if ( - return_response_tokens - ): # if user wants token count with new trimmed messages + if return_response_tokens: # if user wants token count with new trimmed messages response_tokens = max_tokens - get_token_count(final_messages, model) return final_messages, response_tokens return final_messages - except Exception as e: # [NON-Blocking, if error occurs just return final_messages + except Exception as e: # [NON-Blocking, if error occurs just return final_messages print_verbose(f"Got exception while token trimming{e}") return messages - def get_valid_models(): """ Returns a list of valid LLMs based on the set environment variables - + Args: None @@ -7603,13 +6607,13 @@ def get_valid_models(): # edge case litellm has together_ai as a provider, it should be togetherai provider = provider.replace("_", "") - # litellm standardizes expected provider keys to + # litellm standardizes expected provider keys to # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY expected_provider_key = f"{provider.upper()}_API_KEY" - if expected_provider_key in environ_keys: - # key is set + if expected_provider_key in environ_keys: + # key is set valid_providers.append(provider) - + for provider in valid_providers: if provider == "azure": valid_models.append("Azure-LLM") @@ -7618,8 +6622,7 @@ def get_valid_models(): valid_models.extend(models_for_provider) return valid_models except: - return [] # NON-Blocking - + return [] # NON-Blocking # used for litellm.text_completion() to transform HF logprobs to OpenAI.Completion() format def transform_logprobs(hf_response): @@ -7629,39 +6632,40 @@ def transform_logprobs(hf_response): # For each Hugging Face response, transform the logprobs for response in hf_response: # Extract the relevant information from the response - response_details = response["details"] + response_details = response['details'] top_tokens = response_details.get("top_tokens", {}) # Initialize an empty list for the token information token_info = { - "tokens": [], - "token_logprobs": [], - "text_offset": [], - "top_logprobs": [], + 'tokens': [], + 'token_logprobs': [], + 'text_offset': [], + 'top_logprobs': [], } - for i, token in enumerate(response_details["prefill"]): + for i, token in enumerate(response_details['prefill']): # Extract the text of the token - token_text = token["text"] + token_text = token['text'] # Extract the logprob of the token - token_logprob = token["logprob"] + token_logprob = token['logprob'] # Add the token information to the 'token_info' list - token_info["tokens"].append(token_text) - token_info["token_logprobs"].append(token_logprob) + token_info['tokens'].append(token_text) + token_info['token_logprobs'].append(token_logprob) # stub this to work with llm eval harness - top_alt_tokens = {"": -1, "": -2, "": -3} - token_info["top_logprobs"].append(top_alt_tokens) + top_alt_tokens = { "": -1, "": -2, "": -3 } + token_info['top_logprobs'].append(top_alt_tokens) # For each element in the 'tokens' list, extract the relevant information - for i, token in enumerate(response_details["tokens"]): + for i, token 
in enumerate(response_details['tokens']): + # Extract the text of the token - token_text = token["text"] + token_text = token['text'] # Extract the logprob of the token - token_logprob = token["logprob"] + token_logprob = token['logprob'] top_alt_tokens = {} temp_top_logprobs = [] @@ -7675,15 +6679,13 @@ def transform_logprobs(hf_response): top_alt_tokens[text] = logprob # Add the token information to the 'token_info' list - token_info["tokens"].append(token_text) - token_info["token_logprobs"].append(token_logprob) - token_info["top_logprobs"].append(top_alt_tokens) + token_info['tokens'].append(token_text) + token_info['token_logprobs'].append(token_logprob) + token_info['top_logprobs'].append(top_alt_tokens) # Add the text offset of the token # This is computed as the sum of the lengths of all previous tokens - token_info["text_offset"].append( - sum(len(t["text"]) for t in response_details["tokens"][:i]) - ) + token_info['text_offset'].append(sum(len(t['text']) for t in response_details['tokens'][:i])) # Add the 'token_info' list to the 'transformed_logprobs' list transformed_logprobs = token_info
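    # The dict assembled above follows the OpenAI Completion `logprobs` layout; a single entry
    # would look roughly like this (illustrative values only, not from this diff):
    #     {
    #         "tokens": ["Hello", " world"],
    #         "token_logprobs": [-0.31, -1.42],
    #         "top_logprobs": [{"Hello": -0.31, "Hi": -1.9}, {" world": -1.42, " there": -2.7}],
    #         "text_offset": [0, 5],
    #     }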