diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 7aab330b3..41962add0 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -33,7 +33,9 @@ class LangFuseLogger: debug=True, ) - def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + def log_event( + self, kwargs, response_obj, start_time, end_time, user_id, print_verbose + ): # Method definition try: @@ -64,6 +66,7 @@ class LangFuseLogger: output = response_obj["choices"][0]["message"].json() self._log_langfuse_v2( + user_id, metadata, output, start_time, @@ -73,6 +76,7 @@ class LangFuseLogger: input, response_obj, ) if self._is_langfuse_v2() else self._log_langfuse_v1( + user_id, metadata, output, start_time, @@ -93,9 +97,11 @@ class LangFuseLogger: pass async def _async_log_event( - self, kwargs, response_obj, start_time, end_time, print_verbose + self, kwargs, response_obj, start_time, end_time, user_id, print_verbose ): - self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) + self.log_event( + kwargs, response_obj, start_time, end_time, user_id, print_verbose + ) def _is_langfuse_v2(self): import langfuse @@ -104,6 +110,7 @@ class LangFuseLogger: def _log_langfuse_v1( self, + user_id, metadata, output, start_time, @@ -120,6 +127,7 @@ class LangFuseLogger: name=metadata.get("generation_name", "litellm-completion"), input=input, output=output, + userId=user_id, ) ) @@ -142,6 +150,7 @@ class LangFuseLogger: def _log_langfuse_v2( self, + user_id, metadata, output, start_time, @@ -155,6 +164,7 @@ class LangFuseLogger: name=metadata.get("generation_name", "litellm-completion"), input=input, output=output, + user_id=user_id, ) trace.generation( diff --git a/litellm/tests/langfuse.log b/litellm/tests/langfuse.log index 597262903..58c1c8fb2 100644 --- a/litellm/tests/langfuse.log +++ b/litellm/tests/langfuse.log @@ -1,37 +1,2 @@ -Using selector: KqueueSelector -consumer is running... 
-Starting new HTTPS connection (1): litellm-logging.onrender.com:443 -Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'user', 'content': 'This is a test'}], 'model': 'gpt-3.5-turbo', 'max_tokens': 100, 'temperature': 0.7}} -connect_tcp.started host='api.openai.com' port=443 local_address=None timeout=5.0 socket_options=None -connect_tcp.complete return_value= -start_tls.started ssl_context= server_hostname='api.openai.com' timeout=5.0 -start_tls.complete return_value= -send_request_headers.started request= -send_request_headers.complete -send_request_body.started request= -send_request_body.complete -receive_response_headers.started request= -https://litellm-logging.onrender.com:443 "POST /logging HTTP/1.1" 200 38 -receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Mon, 18 Dec 2023 21:53:36 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-allow-origin', b'*'), (b'Cache-Control', b'no-cache, must-revalidate'), (b'openai-model', b'gpt-3.5-turbo-0613'), (b'openai-organization', b'finto-technologies'), (b'openai-processing-ms', b'314'), (b'openai-version', b'2020-10-01'), (b'strict-transport-security', b'max-age=15724800; includeSubDomains'), (b'x-ratelimit-limit-requests', b'5000'), (b'x-ratelimit-limit-tokens', b'160000'), (b'x-ratelimit-limit-tokens_usage_based', b'160000'), (b'x-ratelimit-remaining-requests', b'4999'), (b'x-ratelimit-remaining-tokens', b'159895'), (b'x-ratelimit-remaining-tokens_usage_based', b'159895'), (b'x-ratelimit-reset-requests', b'12ms'), (b'x-ratelimit-reset-tokens', b'39ms'), (b'x-ratelimit-reset-tokens_usage_based', b'39ms'), (b'x-request-id', b'798c68979c33c09835370164b9c3a523'), (b'CF-Cache-Status', b'DYNAMIC'), (b'Set-Cookie', b'__cf_bm=CbrXQ9eH3xFyKA4RzW3z3LlpLb_1pGPWeFTYPtWcE50-1702936416-1-ASb/OMcdGX68dHUk+/wA7xDISru2gTUlUJCwGntKnQ58aBvxa5I6ws5xiY6cXyT8hm9s5bX09Q4Tdb/b85w3rFs=; path=/; expires=Mon, 18-Dec-23 22:23:36 GMT; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Set-Cookie', b'_cfuvid=hNImRsjGg2JqU2MW6VYVAAMPGT99ADf9XOKBz5pJix0-1702936416944-0-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None'), (b'Server', b'cloudflare'), (b'CF-RAY', b'837aa3b94aa51655-WAW'), (b'Content-Encoding', b'gzip'), (b'alt-svc', b'h3=":443"; ma=86400')]) -receive_response_body.started request= -receive_response_body.complete -response_closed.started -response_closed.complete -HTTP Request: POST https://api.openai.com/v1/chat/completions "200 OK" -Creating trace id='5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed' name='litellm-completion' user_id=None input=[[{'role': 'user', 'content': 'This is a test'}]] output={'content': 'Great! What would you like to test?', 'role': 'assistant'} session_id=None release=None version=None metadata=None public=None -adding task {'id': '3ce30ace-129e-4a4d-b9db-ed42cdfc5bc5', 'type': 'trace-create', 'body': {'id': '5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed', 'name': 'litellm-completion', 'input': [[{'role': 'user', 'content': 'This is a test'}]], 'output': {'content': 'Great! What would you like to test?', 'role': 'assistant'}}} -Creating generation trace_id='5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed' name='litellm-completion' start_time=datetime.datetime(2023, 12, 18, 22, 53, 35, 556108) metadata={} input=[[{'role': 'user', 'content': 'This is a test'}]] output={'content': 'Great! 
What would you like to test?', 'role': 'assistant'} level=None status_message=None parent_observation_id=None version=None id='215b1635-46e3-4791-878b-6d76213b8559' end_time=datetime.datetime(2023, 12, 18, 22, 53, 36, 522751) completion_start_time=None model='gpt-3.5-turbo' model_parameters={'temperature': '0.7', 'max_tokens': 100} usage=Usage(input=11, output=9, total=None, unit=)... -item size 348 -adding task {'id': '361e8f67-f46f-42ce-bf9b-e5aab7c5aa38', 'type': 'generation-create', 'body': {'traceId': '5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed', 'name': 'litellm-completion', 'startTime': datetime.datetime(2023, 12, 18, 22, 53, 35, 556108), 'metadata': {}, 'input': [[{'role': 'user', 'content': 'This is a test'}]], 'output': {'content': 'Great! What would you like to test?', 'role': 'assistant'}, 'id': '215b1635-46e3-4791-878b-6d76213b8559', 'endTime': datetime.datetime(2023, 12, 18, 22, 53, 36, 522751), 'model': 'gpt-3.5-turbo', 'modelParameters': {'temperature': '0.7', 'max_tokens': 100}, 'usage': {'input': 11, 'output': 9, 'unit': }}} -flushing queue -item size 659 -uploading batch of 2 items -uploading data: {'batch': [{'id': '3ce30ace-129e-4a4d-b9db-ed42cdfc5bc5', 'type': 'trace-create', 'body': {'id': '5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed', 'name': 'litellm-completion', 'input': [[{'role': 'user', 'content': 'This is a test'}]], 'output': {'content': 'Great! What would you like to test?', 'role': 'assistant'}}, 'timestamp': datetime.datetime(2023, 12, 18, 21, 53, 36, 524507, tzinfo=tzutc())}, {'id': '361e8f67-f46f-42ce-bf9b-e5aab7c5aa38', 'type': 'generation-create', 'body': {'traceId': '5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed', 'name': 'litellm-completion', 'startTime': datetime.datetime(2023, 12, 18, 22, 53, 35, 556108), 'metadata': {}, 'input': [[{'role': 'user', 'content': 'This is a test'}]], 'output': {'content': 'Great! What would you like to test?', 'role': 'assistant'}, 'id': '215b1635-46e3-4791-878b-6d76213b8559', 'endTime': datetime.datetime(2023, 12, 18, 22, 53, 36, 522751), 'model': 'gpt-3.5-turbo', 'modelParameters': {'temperature': '0.7', 'max_tokens': 100}, 'usage': {'input': 11, 'output': 9, 'unit': }}, 'timestamp': datetime.datetime(2023, 12, 18, 21, 53, 36, 525388, tzinfo=tzutc())}], 'metadata': {'batch_size': 2, 'sdk_integration': 'default', 'sdk_name': 'python', 'sdk_version': '2.0.1', 'public_key': 'pk-lf-1234567890'}} -making request: {"batch": [{"id": "3ce30ace-129e-4a4d-b9db-ed42cdfc5bc5", "type": "trace-create", "body": {"id": "5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed", "name": "litellm-completion", "input": [[{"role": "user", "content": "This is a test"}]], "output": {"content": "Great! What would you like to test?", "role": "assistant"}}, "timestamp": "2023-12-18T21:53:36.524507+00:00"}, {"id": "361e8f67-f46f-42ce-bf9b-e5aab7c5aa38", "type": "generation-create", "body": {"traceId": "5b723fd6-0b1b-4c0a-b254-29a6e1fe29ed", "name": "litellm-completion", "startTime": "2023-12-18T22:53:35.556108+00:00", "metadata": {}, "input": [[{"role": "user", "content": "This is a test"}]], "output": {"content": "Great! 
What would you like to test?", "role": "assistant"}, "id": "215b1635-46e3-4791-878b-6d76213b8559", "endTime": "2023-12-18T22:53:36.522751+00:00", "model": "gpt-3.5-turbo", "modelParameters": {"temperature": "0.7", "max_tokens": 100}, "usage": {"input": 11, "output": 9, "unit": "TOKENS"}}, "timestamp": "2023-12-18T21:53:36.525388+00:00"}], "metadata": {"batch_size": 2, "sdk_integration": "default", "sdk_name": "python", "sdk_version": "2.0.1", "public_key": "pk-lf-1234567890"}} to http://localhost:3000/api/public/ingestion -Starting new HTTP connection (1): localhost:3000 -http://localhost:3000 "POST /api/public/ingestion HTTP/1.1" 207 145 -received response: {"errors":[],"successes":[{"id":"3ce30ace-129e-4a4d-b9db-ed42cdfc5bc5","status":201},{"id":"361e8f67-f46f-42ce-bf9b-e5aab7c5aa38","status":201}]} -successfully uploaded batch of 2 items -successfully flushed about 0 items. -joining 1 consumer threads -consumer thread 0 joined +close.started +close.complete diff --git a/litellm/tests/test_langfuse.py b/litellm/tests/test_langfuse.py index 5b9e82410..f77d2acdc 100644 --- a/litellm/tests/test_langfuse.py +++ b/litellm/tests/test_langfuse.py @@ -105,6 +105,7 @@ def test_langfuse_logging_async(): max_tokens=100, temperature=0.7, timeout=5, + user="test_user", ) response = asyncio.run(_test_langfuse()) @@ -198,7 +199,7 @@ def test_langfuse_logging_custom_generation_name(): print(e) -test_langfuse_logging_custom_generation_name() +# test_langfuse_logging_custom_generation_name() @pytest.mark.skip(reason="beta test - checking langfuse output") @@ -235,7 +236,7 @@ def test_langfuse_logging_function_calling(): print(e) -test_langfuse_logging_function_calling() +# test_langfuse_logging_function_calling() def test_langfuse_logging_tool_calling(): @@ -296,4 +297,4 @@ def test_langfuse_logging_tool_calling(): tool_calls = response.choices[0].message.tool_calls -test_langfuse_logging_tool_calling() +# test_langfuse_logging_tool_calling() diff --git a/litellm/utils.py b/litellm/utils.py index 45d5d02f0..4c1cb2efd 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -27,6 +27,7 @@ from dataclasses import ( dataclass, field, ) # for storing API inputs, outputs, and metadata + encoding = tiktoken.get_encoding("cl100k_base") import importlib.metadata from .integrations.traceloop import TraceloopLogger @@ -56,16 +57,17 @@ from .exceptions import ( APIConnectionError, APIError, BudgetExceededError, - UnprocessableEntityError + UnprocessableEntityError, ) from typing import cast, List, Dict, Union, Optional, Literal from .caching import Cache from concurrent.futures import ThreadPoolExecutor + ####### ENVIRONMENT VARIABLES #################### # Adjust to your specific application needs / system capabilities. 
-MAX_THREADS = 100 +MAX_THREADS = 100 -# Create a ThreadPoolExecutor +# Create a ThreadPoolExecutor executor = ThreadPoolExecutor(max_workers=MAX_THREADS) dotenv.load_dotenv() # Loading env variables using dotenv sentry_sdk_instance = None @@ -111,6 +113,7 @@ last_fetched_at_keys = None # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41} # } + class UnsupportedParamsError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -122,64 +125,81 @@ class UnsupportedParamsError(Exception): ) # Call the base class constructor with the parameters it needs -def _generate_id(): # private helper function - return 'chatcmpl-' + str(uuid.uuid4()) +def _generate_id(): # private helper function + return "chatcmpl-" + str(uuid.uuid4()) -def map_finish_reason(finish_reason: str): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null' + +def map_finish_reason( + finish_reason: str, +): # openai supports 5 stop sequences - 'stop', 'length', 'function_call', 'content_filter', 'null' # anthropic mapping if finish_reason == "stop_sequence": return "stop" # cohere mapping - https://docs.cohere.com/reference/generate - elif finish_reason == "COMPLETE": + elif finish_reason == "COMPLETE": return "stop" - elif finish_reason == "MAX_TOKENS": # cohere + vertex ai + elif finish_reason == "MAX_TOKENS": # cohere + vertex ai return "length" - elif finish_reason == "ERROR_TOXIC": + elif finish_reason == "ERROR_TOXIC": return "content_filter" - elif finish_reason == "ERROR": # openai currently doesn't support an 'error' finish reason + elif ( + finish_reason == "ERROR" + ): # openai currently doesn't support an 'error' finish reason return "stop" # huggingface mapping https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate_stream elif finish_reason == "eos_token" or finish_reason == "stop_sequence": return "stop" - elif finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP": # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',] + elif ( + finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP" + ): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',] return "stop" - elif finish_reason == "SAFETY": # vertex ai + elif finish_reason == "SAFETY": # vertex ai return "content_filter" return finish_reason + class FunctionCall(OpenAIObject): arguments: str name: str + class Function(OpenAIObject): arguments: str name: str + class ChatCompletionMessageToolCall(OpenAIObject): id: str function: Function type: str + class Message(OpenAIObject): - def __init__(self, content="default", role="assistant", logprobs=None, function_call=None, tool_calls=None, **params): + def __init__( + self, + content="default", + role="assistant", + logprobs=None, + function_call=None, + tool_calls=None, + **params, + ): super(Message, self).__init__(**params) self.content = content self.role = role - if function_call is not None: + if function_call is not None: self.function_call = FunctionCall(**function_call) if tool_calls is not None: self.tool_calls = [] for tool_call in tool_calls: - self.tool_calls.append( - ChatCompletionMessageToolCall(**tool_call) - ) + self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call)) if logprobs is not 
None: - self._logprobs = logprobs + self._logprobs = logprobs def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -190,7 +210,7 @@ class Message(OpenAIObject): def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() @@ -201,7 +221,7 @@ class Delta(OpenAIObject): super(Delta, self).__init__(**params) self.content = content self.role = role - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -209,7 +229,7 @@ class Delta(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -222,13 +242,15 @@ class Delta(OpenAIObject): class Choices(OpenAIObject): def __init__(self, finish_reason=None, index=0, message=None, **params): super(Choices, self).__init__(**params) - self.finish_reason = map_finish_reason(finish_reason) or "stop" # set finish_reason for all responses + self.finish_reason = ( + map_finish_reason(finish_reason) or "stop" + ) # set finish_reason for all responses self.index = index if message is None: self.message = Message(content=None) else: self.message = message - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -236,7 +258,7 @@ class Choices(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -245,8 +267,11 @@ class Choices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class Usage(OpenAIObject): - def __init__(self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params): + def __init__( + self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params + ): super(Usage, self).__init__(**params) if prompt_tokens: self.prompt_tokens = prompt_tokens @@ -254,15 +279,15 @@ class Usage(OpenAIObject): self.completion_tokens = completion_tokens if total_tokens: self.total_tokens = total_tokens - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -271,8 +296,11 @@ class Usage(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class StreamingChoices(OpenAIObject): - def __init__(self, finish_reason=None, index=0, delta: Optional[Delta]=None, **params): + def __init__( + self, finish_reason=None, index=0, delta: Optional[Delta] = None, **params + ): super(StreamingChoices, self).__init__(**params) if finish_reason: self.finish_reason = finish_reason @@ -283,15 +311,15 @@ class StreamingChoices(OpenAIObject): self.delta = delta else: self.delta = Delta() - + def 
__contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -300,7 +328,8 @@ class StreamingChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) -class ModelResponse(OpenAIObject): + +class ModelResponse(OpenAIObject): id: str """A unique identifier for the completion.""" @@ -328,7 +357,20 @@ class ModelResponse(OpenAIObject): _hidden_params: dict = {} - def __init__(self, id=None, choices=None, created=None, model=None, object=None, system_fingerprint=None, usage=None, stream=False, response_ms=None, hidden_params=None, **params): + def __init__( + self, + id=None, + choices=None, + created=None, + model=None, + object=None, + system_fingerprint=None, + usage=None, + stream=False, + response_ms=None, + hidden_params=None, + **params, + ): if stream: object = "chat.completion.chunk" choices = [StreamingChoices()] @@ -353,16 +395,25 @@ class ModelResponse(OpenAIObject): usage = Usage() if hidden_params: self._hidden_params = hidden_params - super().__init__(id=id, choices=choices, created=created, model=model, object=object, system_fingerprint=system_fingerprint, usage=usage, **params) - + super().__init__( + id=id, + choices=choices, + created=created, + model=model, + object=object, + system_fingerprint=system_fingerprint, + usage=usage, + **params, + ) + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -370,14 +421,15 @@ class ModelResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() + class Embedding(OpenAIObject): embedding: list = [] index: int @@ -386,7 +438,7 @@ class Embedding(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -395,6 +447,7 @@ class Embedding(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class EmbeddingResponse(OpenAIObject): model: Optional[str] = None """The model used for embedding.""" @@ -408,17 +461,19 @@ class EmbeddingResponse(OpenAIObject): usage: Optional[Usage] = None """Usage statistics for the embedding request.""" - def __init__(self, model=None, usage=None, stream=False, response_ms=None, data=None): + def __init__( + self, model=None, usage=None, stream=False, response_ms=None, data=None + ): object = "list" if response_ms: _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if usage: usage = usage else: @@ -430,11 +485,11 @@ class EmbeddingResponse(OpenAIObject): def __contains__(self, key): # Define 
custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -442,14 +497,15 @@ class EmbeddingResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() + class TextChoices(OpenAIObject): def __init__(self, finish_reason=None, index=0, text=None, logprobs=None, **params): super(TextChoices, self).__init__(**params) @@ -466,7 +522,7 @@ class TextChoices(OpenAIObject): self.logprobs = [] else: self.logprobs = logprobs - + def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -474,7 +530,7 @@ class TextChoices(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -483,6 +539,7 @@ class TextChoices(OpenAIObject): # Allow dictionary-style assignment of attributes setattr(self, key, value) + class TextCompletionResponse(OpenAIObject): """ { @@ -501,7 +558,18 @@ class TextCompletionResponse(OpenAIObject): "usage": response["usage"] } """ - def __init__(self, id=None, choices=None, created=None, model=None, usage=None, stream=False, response_ms=None, **params): + + def __init__( + self, + id=None, + choices=None, + created=None, + model=None, + usage=None, + stream=False, + response_ms=None, + **params, + ): super(TextCompletionResponse, self).__init__(**params) if stream: self.object = "text_completion.chunk" @@ -526,9 +594,10 @@ class TextCompletionResponse(OpenAIObject): self.usage = usage else: self.usage = Usage() - self._hidden_params = {} # used in case users want to access the original model response + self._hidden_params = ( + {} + ) # used in case users want to access the original model response - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -536,7 +605,7 @@ class TextCompletionResponse(OpenAIObject): def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # Allow dictionary-style access to attributes return getattr(self, key) @@ -548,7 +617,7 @@ class TextCompletionResponse(OpenAIObject): class ImageResponse(OpenAIObject): created: Optional[int] = None - + data: Optional[list] = None def __init__(self, created=None, data=None, response_ms=None): @@ -556,11 +625,11 @@ class ImageResponse(OpenAIObject): _response_ms = response_ms else: _response_ms = None - if data: + if data: data = data - else: + else: data = None - + if created: created = created else: @@ -571,11 +640,11 @@ class ImageResponse(OpenAIObject): def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) - + def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist return getattr(self, key, default) - + def __getitem__(self, key): # 
Allow dictionary-style access to attributes return getattr(self, key) @@ -583,52 +652,69 @@ class ImageResponse(OpenAIObject): def __setitem__(self, key, value): # Allow dictionary-style assignment of attributes setattr(self, key, value) - + def json(self, **kwargs): try: - return self.model_dump() # noqa + return self.model_dump() # noqa except: # if using pydantic v1 return self.dict() - + + ############################################################ def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass + ####### LOGGING ################### from enum import Enum + class CallTypes(Enum): - embedding = 'embedding' - completion = 'completion' - acompletion = 'acompletion' - aembedding = 'aembedding' - image_generation = 'image_generation' - aimage_generation = 'aimage_generation' + embedding = "embedding" + completion = "completion" + acompletion = "acompletion" + aembedding = "aembedding" + image_generation = "image_generation" + aimage_generation = "aimage_generation" + # Logging function -> log the exact model details + what's being sent | Non-Blocking class Logging: global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, capture_exception, add_breadcrumb, llmonitorLogger - def __init__(self, model, messages, stream, call_type, start_time, litellm_call_id, function_id): + def __init__( + self, + model, + messages, + stream, + call_type, + start_time, + litellm_call_id, + function_id, + ): if call_type not in [item.value for item in CallTypes]: allowed_values = ", ".join([item.value for item in CallTypes]) - raise ValueError(f"Invalid call_type {call_type}. Allowed values: {allowed_values}") + raise ValueError( + f"Invalid call_type {call_type}. 
Allowed values: {allowed_values}" + ) self.model = model self.messages = messages self.stream = stream - self.start_time = start_time # log the call start time + self.start_time = start_time # log the call start time self.call_type = call_type self.litellm_call_id = litellm_call_id self.function_id = function_id - self.streaming_chunks = [] # for generating complete stream response + self.streaming_chunks = [] # for generating complete stream response self.model_call_details = {} - - def update_environment_variables(self, model, user, optional_params, litellm_params, **additional_params): + + def update_environment_variables( + self, model, user, optional_params, litellm_params, **additional_params + ): self.optional_params = optional_params self.model = model self.user = user @@ -645,10 +731,10 @@ class Logging: "user": user, "call_type": str(self.call_type), **self.optional_params, - **additional_params + **additional_params, } - def _pre_call(self, input, api_key, model=None, additional_args={}): + def _pre_call(self, input, api_key, model=None, additional_args={}): """ Common helper function across the sync + async pre-call function """ @@ -658,31 +744,43 @@ class Logging: self.model_call_details["additional_args"] = additional_args self.model_call_details["log_event_type"] = "pre_api_call" if ( - model - ): # if model name was changes pre-call, overwrite the initial model call name with the new one - self.model_call_details["model"] = model + model + ): # if model name was changes pre-call, overwrite the initial model call name with the new one + self.model_call_details["model"] = model def pre_call(self, input, api_key, model=None, additional_args={}): # Log the exact input to the LLM API - litellm.error_logs['PRE_CALL'] = locals() + litellm.error_logs["PRE_CALL"] = locals() try: - self._pre_call(input=input, api_key=api_key, model=model, additional_args=additional_args) + self._pre_call( + input=input, + api_key=api_key, + model=model, + additional_args=additional_args, + ) # User Logging -> if you pass in a custom logging function headers = additional_args.get("headers", {}) - if headers is None: + if headers is None: headers = {} data = additional_args.get("complete_input_dict", {}) api_base = additional_args.get("api_base", "") - masked_headers = {k: (v[:-20] + '*' * 20) if (isinstance(v, str) and len(v) > 20) else v for k, v in headers.items()} - formatted_headers = " ".join([f"-H '{k}: {v}'" for k, v in masked_headers.items()]) + masked_headers = { + k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v + for k, v in headers.items() + } + formatted_headers = " ".join( + [f"-H '{k}: {v}'" for k, v in masked_headers.items()] + ) print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command += "curl -X POST \\\n" curl_command += f"{api_base} \\\n" - curl_command += f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" + curl_command += ( + f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else "" + ) curl_command += f"-d '{str(data)}'\n" if additional_args.get("request_str", None) is not None: # print the sagemaker / bedrock client request @@ -703,10 +801,17 @@ class Logging: if litellm.max_budget and self.stream: start_time = self.start_time - end_time = self.start_time # no time has passed as the call hasn't been made yet + end_time = ( + self.start_time + ) # no time has passed as the call hasn't been made yet time_diff = (end_time - start_time).total_seconds() 
float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost(model=self.model, prompt="".join(message["content"] for message in self.messages), completion="", total_time=float_diff) + litellm._current_cost += litellm.completion_cost( + model=self.model, + prompt="".join(message["content"] for message in self.messages), + completion="", + total_time=float_diff, + ) # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: @@ -725,7 +830,9 @@ class Logging: ) elif callback == "lite_debugger": - print_verbose(f"reaches litedebugger for logging! - model_call_details {self.model_call_details}") + print_verbose( + f"reaches litedebugger for logging! - model_call_details {self.model_call_details}" + ) model = self.model_call_details["model"] messages = self.model_call_details["input"] print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") @@ -737,7 +844,7 @@ class Logging: litellm_params=self.model_call_details["litellm_params"], optional_params=self.model_call_details["optional_params"], print_verbose=print_verbose, - call_type=self.call_type + call_type=self.call_type, ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -746,19 +853,19 @@ class Logging: message=f"Model Call Details pre-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_pre_api_call( model=self.model, messages=self.messages, kwargs=self.model_call_details, ) - elif callable(callback): # custom logger functions + elif callable(callback): # custom logger functions customLogger.log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) except Exception as e: traceback.print_exc() @@ -780,37 +887,48 @@ class Logging: if capture_exception: # log this error to sentry for debugging capture_exception(e) - async def async_pre_call(self, result=None, start_time=None, end_time=None, **kwargs): + async def async_pre_call( + self, result=None, start_time=None, end_time=None, **kwargs + ): """ - Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. + Â Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result) + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, end_time=end_time, result=result + ) print_verbose(f"Async input callbacks: {litellm._async_input_callback}") for callback in litellm._async_input_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async input callbacks: CustomLogger") - asyncio.create_task(callback.async_log_input_event( + asyncio.create_task( + callback.async_log_input_event( model=self.model, messages=self.messages, kwargs=self.model_call_details, - )) - if callable(callback): # custom logger functions + ) + ) + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") - asyncio.create_task(customLogger.async_log_input_event( - model=self.model, - messages=self.messages, - kwargs=self.model_call_details, - print_verbose=print_verbose, - callback_func=callback - )) - except: + asyncio.create_task( + customLogger.async_log_input_event( + model=self.model, + messages=self.messages, + kwargs=self.model_call_details, + print_verbose=print_verbose, + callback_func=callback, + ) + ) + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) - def post_call(self, original_response, input=None, api_key=None, additional_args={}): + + def post_call( + self, original_response, input=None, api_key=None, additional_args={} + ): # Log the exact result from the LLM API, for streaming - log the type of response received - litellm.error_logs['POST_CALL'] = locals() + litellm.error_logs["POST_CALL"] = locals() try: self.model_call_details["input"] = input self.model_call_details["api_key"] = api_key @@ -819,11 +937,15 @@ class Logging: self.model_call_details["log_event_type"] = "post_api_call" # User Logging -> if you pass in a custom logging function - print_verbose(f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n") + print_verbose( + f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n" + ) print_verbose( f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" ) - print_verbose(f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}") + print_verbose( + f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}" + ) if self.logger_fn and callable(self.logger_fn): try: self.logger_fn( @@ -833,7 +955,7 @@ class Logging: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: try: @@ -844,8 +966,8 @@ class Logging: original_response=original_response, litellm_call_id=self.litellm_params["litellm_call_id"], print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, + call_type=self.call_type, + stream=self.stream, ) elif callback == "sentry" and add_breadcrumb: print_verbose("reaches sentry breadcrumbing") @@ -854,12 +976,12 @@ class Logging: message=f"Model Call Details post-call: {self.model_call_details}", level="info", ) - elif isinstance(callback, CustomLogger): # custom logger 
class + elif isinstance(callback, CustomLogger): # custom logger class callback.log_post_api_call( kwargs=self.model_call_details, response_obj=None, start_time=self.start_time, - end_time=None + end_time=None, ) except Exception as e: print_verbose( @@ -875,9 +997,11 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) pass - - def _success_handler_helper_fn(self, result=None, start_time=None, end_time=None, cache_hit=None): - try: + + def _success_handler_helper_fn( + self, result=None, start_time=None, end_time=None, cache_hit=None + ): + try: if start_time is None: start_time = self.start_time if end_time is None: @@ -889,42 +1013,67 @@ class Logging: if litellm.max_budget and self.stream: time_diff = (end_time - start_time).total_seconds() float_diff = float(time_diff) - litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff) + litellm._current_cost += litellm.completion_cost( + model=self.model, + prompt="", + completion=result["content"], + total_time=float_diff, + ) return start_time, end_time, result - except Exception as e: + except Exception as e: print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}") - def success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): - print_verbose( - f"Logging Details LiteLLM-Success Call" - ) + def success_handler( + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): + print_verbose(f"Logging Details LiteLLM-Success Call") # print(f"original response in success handler: {self.model_call_details['original_response']}") try: - print_verbose(f"success callbacks: {litellm.success_callback}") + print_verbose(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # only call stream chunk builder if it's not acompletion() - if result.choices[0].finish_reason is not None: # if it's the last chunk + if ( + self.stream + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + ): # only call stream chunk builder if it's not acompletion() + if ( + result.choices[0].finish_reason is not None + ): # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) - except: + try: + complete_streaming_response = litellm.stream_chunk_builder( + self.streaming_chunks, + messages=self.model_call_details.get("messages", None), + ) + except: complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: - self.model_call_details["complete_streaming_response"] = complete_streaming_response + if complete_streaming_response: + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, + end_time=end_time, + result=result, + cache_hit=cache_hit, + ) for callback in litellm.success_callback: try: if callback == 
"lite_debugger": print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - print_verbose(f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}") + print_verbose( + f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}" + ) liteDebuggerClient.log_event( end_user=kwargs.get("user", "default"), response_obj=result, @@ -932,8 +1081,8 @@ class Logging: end_time=end_time, litellm_call_id=self.litellm_call_id, print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, + call_type=self.call_type, + stream=self.stream, ) if callback == "promptlayer": print_verbose("reaches promptlayer for logging!") @@ -946,8 +1095,8 @@ class Logging: ) if callback == "supabase": print_verbose("reaches supabase for logging!") - kwargs=self.model_call_details - + kwargs = self.model_call_details + # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: if "complete_streaming_response" not in kwargs: @@ -955,7 +1104,7 @@ class Logging: else: print_verbose("reaches supabase for streaming logging!") result = kwargs["complete_streaming_response"] - + model = kwargs["model"] messages = kwargs["messages"] optional_params = kwargs.get("optional_params", {}) @@ -967,7 +1116,9 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - litellm_call_id=litellm_params.get("litellm_call_id", str(uuid.uuid4())), + litellm_call_id=litellm_params.get( + "litellm_call_id", str(uuid.uuid4()) + ), print_verbose=print_verbose, ) if callback == "wandb": @@ -992,10 +1143,16 @@ class Logging: print_verbose("reaches llmonitor for logging!") model = self.model - input = self.model_call_details.get("messages", self.model_call_details.get("input", None)) + input = self.model_call_details.get( + "messages", self.model_call_details.get("input", None) + ) # if contains input, it's 'embedding', otherwise 'llm' - type = "embed" if self.call_type == CallTypes.embedding.value else "llm" + type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) llmonitorLogger.log_event( type=type, @@ -1025,8 +1182,10 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1042,6 +1201,7 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, + user_id=self.model_call_details.get("user", "default"), print_verbose=print_verbose, ) if callback == "cache" and litellm.cache is not None: @@ -1050,17 +1210,21 @@ class Logging: kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose(f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") + print_verbose( + f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" + ) return else: - print_verbose("success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache") + print_verbose( + "success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache" + ) result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) if callback == "traceloop": deep_copy = {} - for k, v in self.model_call_details.items(): - if k != "original_response": + for k, v in self.model_call_details.items(): + if k != "original_response": deep_copy[k] = v traceloopLogger.log_event( kwargs=deep_copy, @@ -1069,18 +1233,32 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class - print_verbose(f"success callbacks: Running Custom Logger Class") + elif ( + isinstance(callback, CustomLogger) + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "aembedding", False + ) + == False + ): # custom logger class + print_verbose(f"success callbacks: Running Custom Logger Class") if self.stream and complete_streaming_response is None: callback.log_stream_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time - ) + end_time=end_time, + ) else: if self.stream and complete_streaming_response: - self.model_call_details["complete_response"] = self.model_call_details.get("complete_streaming_response", {}) + self.model_call_details[ + "complete_response" + ] = self.model_call_details.get( + "complete_streaming_response", {} + ) result = self.model_call_details["complete_response"] callback.log_success_event( kwargs=self.model_call_details, @@ -1088,15 +1266,17 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions - print_verbose(f"success callbacks: Running Custom Callback Function") + if callable(callback): # custom logger functions + print_verbose( + f"success callbacks: Running Custom Callback Function" + ) customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) except Exception as e: @@ -1114,60 +1294,77 @@ class Logging: ) pass - async def async_success_handler(self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs): + async def async_success_handler( + self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs + ): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" print_verbose(f"Async success callbacks: {litellm._async_success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream: - if result.choices[0].finish_reason is not None: # if it's the last chunk + if self.stream: + if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result) # print_verbose(f"final set of received chunks: {self.streaming_chunks}") - try: - complete_streaming_response = litellm.stream_chunk_builder(self.streaming_chunks, messages=self.model_call_details.get("messages", None)) + try: + complete_streaming_response = litellm.stream_chunk_builder( + self.streaming_chunks, + messages=self.model_call_details.get("messages", None), + ) except Exception as e: - print_verbose(f"Error occurred building stream chunk: {traceback.format_exc()}") + print_verbose( + f"Error occurred building stream chunk: {traceback.format_exc()}" + ) complete_streaming_response = None else: self.streaming_chunks.append(result) - if complete_streaming_response: + if complete_streaming_response: print_verbose("Async success callbacks: Got a complete streaming response") - self.model_call_details["complete_streaming_response"] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn(start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit) + self.model_call_details[ + "complete_streaming_response" + ] = complete_streaming_response + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit + ) for callback in litellm._async_success_callback: - try: + try: if callback == "cache" and litellm.cache is not None: # set_cache once complete streaming response is built print_verbose("async success_callback: reaches cache for logging!") kwargs = self.model_call_details if self.stream: if "complete_streaming_response" not in kwargs: - print_verbose(f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n") + print_verbose( + f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n" + ) return else: - print_verbose("async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache") + print_verbose( + "async success_callback: reaches cache for logging, there is a complete_streaming_response. 
Adding to cache" + ) result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) - if isinstance(callback, CustomLogger): # custom logger class + if isinstance(callback, CustomLogger): # custom logger class print_verbose(f"Async success callbacks: CustomLogger") if self.stream: if "complete_streaming_response" in self.model_call_details: await callback.async_log_success_event( kwargs=self.model_call_details, - response_obj=self.model_call_details["complete_streaming_response"], + response_obj=self.model_call_details[ + "complete_streaming_response" + ], start_time=start_time, end_time=end_time, ) - else: - await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function + else: + await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function kwargs=self.model_call_details, response_obj=result, start_time=start_time, - end_time=end_time - ) + end_time=end_time, + ) else: await callback.async_log_success_event( kwargs=self.model_call_details, @@ -1175,7 +1372,7 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions print_verbose(f"Async success callbacks: async_log_event") await customLogger.async_log_event( kwargs=self.model_call_details, @@ -1183,7 +1380,7 @@ class Logging: start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) if callback == "dynamodb": global dynamoLogger @@ -1191,16 +1388,22 @@ class Logging: dynamoLogger = DyanmoDBLogger() if self.stream: if "complete_streaming_response" in self.model_call_details: - print_verbose("DynamoDB Logger: Got Stream Event - Completed Stream Response") + print_verbose( + "DynamoDB Logger: Got Stream Event - Completed Stream Response" + ) await dynamoLogger._async_log_event( kwargs=self.model_call_details, - response_obj=self.model_call_details["complete_streaming_response"], + response_obj=self.model_call_details[ + "complete_streaming_response" + ], start_time=start_time, end_time=end_time, - print_verbose=print_verbose + print_verbose=print_verbose, + ) + else: + print_verbose( + "DynamoDB Logger: Got Stream Event - No complete stream response as yet" ) - else: - print_verbose("DynamoDB Logger: Got Stream Event - No complete stream response as yet") else: await dynamoLogger._async_log_event( kwargs=self.model_call_details, @@ -1213,8 +1416,10 @@ class Logging: global langFuseLogger print_verbose("reaches langfuse for logging!") kwargs = {} - for k, v in self.model_call_details.items(): - if k != "original_response": # copy.deepcopy raises errors as this could be a coroutine + for k, v in self.model_call_details.items(): + if ( + k != "original_response" + ): # copy.deepcopy raises errors as this could be a coroutine kwargs[k] = v # this only logs streaming once, complete_streaming_response exists i.e when stream ends if self.stream: @@ -1227,18 +1432,21 @@ class Logging: langFuseLogger = LangFuseLogger() await langFuseLogger._async_log_event( kwargs=kwargs, + user_id=self.model_call_details.get("user", "default"), response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, ) - except: + except: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) pass - def _failure_handler_helper_fn(self, exception, 
traceback_exception, start_time=None, end_time=None): + def _failure_handler_helper_fn( + self, exception, traceback_exception, start_time=None, end_time=None + ): if start_time is None: start_time = self.start_time if end_time is None: @@ -1255,49 +1463,58 @@ class Logging: self.model_call_details.setdefault("original_response", None) return start_time, end_time - def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): - print_verbose( - f"Logging Details LiteLLM-Failure Call" - ) + def failure_handler( + self, exception, traceback_exception, start_time=None, end_time=None + ): + print_verbose(f"Logging Details LiteLLM-Failure Call") try: - start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) - result = None # result sent to all loggers, init this to None incase it's not created + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm.failure_callback: try: if callback == "lite_debugger": - print_verbose("reaches lite_debugger for logging!") - print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - result = { - "model": self.model, - "created": time.time(), - "error": traceback_exception, - "usage": { - "prompt_tokens": prompt_token_calculator( - self.model, messages=self.messages - ), - "completion_tokens": 0, - }, - } - liteDebuggerClient.log_event( - model=self.model, - messages=self.messages, - end_user=self.model_call_details.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=self.litellm_call_id, - print_verbose=print_verbose, - call_type = self.call_type, - stream = self.stream, - ) + print_verbose("reaches lite_debugger for logging!") + print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") + result = { + "model": self.model, + "created": time.time(), + "error": traceback_exception, + "usage": { + "prompt_tokens": prompt_token_calculator( + self.model, messages=self.messages + ), + "completion_tokens": 0, + }, + } + liteDebuggerClient.log_event( + model=self.model, + messages=self.messages, + end_user=self.model_call_details.get("user", "default"), + response_obj=result, + start_time=start_time, + end_time=end_time, + litellm_call_id=self.litellm_call_id, + print_verbose=print_verbose, + call_type=self.call_type, + stream=self.stream, + ) elif callback == "llmonitor": print_verbose("reaches llmonitor for logging error!") model = self.model input = self.model_call_details["input"] - - type = "embed" if self.call_type == CallTypes.embedding.value else "llm" + + type = ( + "embed" + if self.call_type == CallTypes.embedding.value + else "llm" + ) llmonitorLogger.log_event( type=type, @@ -1316,17 +1533,29 @@ class Logging: if capture_exception: capture_exception(exception) else: - print_verbose(f"capture exception not initialized: {capture_exception}") - elif callable(callback): # custom logger functions + print_verbose( + f"capture exception not initialized: {capture_exception}" + ) + elif callable(callback): # custom logger functions customLogger.log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback + callback_func=callback, ) - elif 
isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class + elif ( + isinstance(callback, CustomLogger) + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "aembedding", False + ) + == False + ): # custom logger class callback.log_failure_event( start_time=start_time, end_time=end_time, @@ -1348,37 +1577,43 @@ class Logging: ) pass - async def async_failure_handler(self, exception, traceback_exception, start_time=None, end_time=None): + async def async_failure_handler( + self, exception, traceback_exception, start_time=None, end_time=None + ): """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. """ - start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time) - result = None # result sent to all loggers, init this to None incase it's not created + start_time, end_time = self._failure_handler_helper_fn( + exception=exception, + traceback_exception=traceback_exception, + start_time=start_time, + end_time=end_time, + ) + result = None # result sent to all loggers, init this to None incase it's not created for callback in litellm._async_failure_callback: - try: - if isinstance(callback, CustomLogger): # custom logger class + try: + if isinstance(callback, CustomLogger): # custom logger class await callback.async_log_failure_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + if callable(callback): # custom logger functions await customLogger.async_log_event( kwargs=self.model_call_details, response_obj=result, start_time=start_time, end_time=end_time, print_verbose=print_verbose, - callback_func=callback - ) - except Exception as e: + callback_func=callback, + ) + except Exception as e: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" ) - def exception_logging( additional_args={}, logger_fn=None, @@ -1411,53 +1646,59 @@ def exception_logging( ####### RULES ################### -class Rules: + +class Rules: """ Fail calls based on the input or llm api output - Example usage: - import litellm - def my_custom_rule(input): # receives the model response - if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer - return False - return True - + Example usage: + import litellm + def my_custom_rule(input): # receives the model response + if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer + return False + return True + litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call - response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", - "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) + response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", + "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"]) """ + def __init__(self) -> None: pass - def pre_call_rules(self, input: str, model: str): - for rule in litellm.pre_call_rules: - if callable(rule): + def pre_call_rules(self, input: 
str, model: str): + for rule in litellm.pre_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore - return True + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True - def post_call_rules(self, input: str, model: str): - for rule in litellm.post_call_rules: - if callable(rule): + def post_call_rules(self, input: str, model: str): + for rule in litellm.post_call_rules: + if callable(rule): decision = rule(input) if decision is False: - raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore - return True + raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider="", model=model) # type: ignore + return True + ####### CLIENT ################### # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking def client(original_function): global liteDebuggerClient, get_all_keys rules_obj = Rules() + def function_setup( start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. try: global callback_list, add_breadcrumb, user_logger_fn, Logging function_id = kwargs["id"] if "id" in kwargs else None - if litellm.use_client or ("use_client" in kwargs and kwargs["use_client"] == True): + if litellm.use_client or ( + "use_client" in kwargs and kwargs["use_client"] == True + ): print_verbose(f"litedebugger initialized") if "lite_debugger" not in litellm.input_callback: litellm.input_callback.append("lite_debugger") @@ -1465,8 +1706,8 @@ def client(original_function): litellm.success_callback.append("lite_debugger") if "lite_debugger" not in litellm.failure_callback: litellm.failure_callback.append("lite_debugger") - if len(litellm.callbacks) > 0: - for callback in litellm.callbacks: + if len(litellm.callbacks) > 0: + for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) if callback not in litellm.success_callback: @@ -1477,7 +1718,9 @@ def client(original_function): litellm._async_success_callback.append(callback) if callback not in litellm._async_failure_callback: litellm._async_failure_callback.append(callback) - print_verbose(f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}") + print_verbose( + f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" + ) if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 @@ -1490,10 +1733,7 @@ def client(original_function): + litellm.failure_callback ) ) - set_callbacks( - callback_list=callback_list, - function_id=function_id - ) + set_callbacks(callback_list=callback_list, function_id=function_id) ## ASYNC CALLBACKS if len(litellm.input_callback) > 0: removed_async_items = [] @@ -1506,10 +1746,10 @@ def client(original_function): for index in reversed(removed_async_items): litellm.input_callback.pop(index) - if len(litellm.success_callback) > 0: + if len(litellm.success_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.success_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.success_callback): + if 
inspect.iscoroutinefunction(callback): litellm._async_success_callback.append(callback) removed_async_items.append(index) elif callback == "dynamodb": @@ -1517,7 +1757,9 @@ def client(original_function): # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy litellm._async_success_callback.append(callback) removed_async_items.append(index) - elif callback == "langfuse" and inspect.iscoroutinefunction(original_function): + elif callback == "langfuse" and inspect.iscoroutinefunction( + original_function + ): # use async success callback for langfuse if this is litellm.acompletion(). Streaming logging does not work otherwise litellm._async_success_callback.append(callback) removed_async_items.append(index) @@ -1525,11 +1767,11 @@ def client(original_function): # Pop the async items from success_callback in reverse order to avoid index issues for index in reversed(removed_async_items): litellm.success_callback.pop(index) - - if len(litellm.failure_callback) > 0: + + if len(litellm.failure_callback) > 0: removed_async_items = [] - for index, callback in enumerate(litellm.failure_callback): - if inspect.iscoroutinefunction(callback): + for index, callback in enumerate(litellm.failure_callback): + if inspect.iscoroutinefunction(callback): litellm._async_failure_callback.append(callback) removed_async_items.append(index) @@ -1549,35 +1791,70 @@ def client(original_function): # INIT LOGGER - for user-specified integrations model = args[0] if len(args) > 0 else kwargs.get("model", None) call_type = original_function.__name__ - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + ): messages = None if len(args) > 1: - messages = args[1] + messages = args[1] elif kwargs.get("messages", None): messages = kwargs["messages"] - ### PRE-CALL RULES ### - if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) and "content" in messages[0]: - rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model) - elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value: + ### PRE-CALL RULES ### + if ( + isinstance(messages, list) + and len(messages) > 0 + and isinstance(messages[0], dict) + and "content" in messages[0] + ): + rules_obj.pre_call_rules( + input="".join( + m["content"] + for m in messages + if isinstance(m["content"], str) + ), + model=model, + ) + elif ( + call_type == CallTypes.embedding.value + or call_type == CallTypes.aembedding.value + ): messages = args[1] if len(args) > 1 else kwargs["input"] - elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value: + elif ( + call_type == CallTypes.image_generation.value + or call_type == CallTypes.aimage_generation.value + ): messages = args[0] if len(args) > 0 else kwargs["prompt"] stream = True if "stream" in kwargs and kwargs["stream"] == True else False - logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time) + logging_obj = Logging( + model=model, + messages=messages, + stream=stream, + litellm_call_id=kwargs["litellm_call_id"], + function_id=function_id, + call_type=call_type, + start_time=start_time, + ) return logging_obj - except Exception as e: + except Exception as e: import logging - 
logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}") + + logging.debug( + f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}" + ) raise e - + def post_call_processing(original_response, model): - try: + try: call_type = original_function.__name__ - if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value: - model_response = original_response['choices'][0]['message']['content'] - ### POST-CALL RULES ### + if ( + call_type == CallTypes.completion.value + or call_type == CallTypes.acompletion.value + ): + model_response = original_response["choices"][0]["message"]["content"] + ### POST-CALL RULES ### rules_obj.post_call_rules(input=model_response, model=model) - except Exception as e: + except Exception as e: raise e def crash_reporting(*args, **kwargs): @@ -1610,7 +1887,7 @@ def client(original_function): try: model = args[0] if len(args) > 0 else kwargs["model"] except: - call_type = original_function.__name__ + call_type = original_function.__name__ if call_type != CallTypes.image_generation.value: raise ValueError("model param not passed in.") @@ -1619,84 +1896,130 @@ def client(original_function): logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) # [OPTIONAL] CHECK CACHE - print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") - # if caching is false, don't run this - if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function + print_verbose( + f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" + ) + # if caching is false, don't run this + if ( + kwargs.get("caching", None) is None and litellm.cache is not None + ) or kwargs.get( + "caching", False + ) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): print_verbose(f"Checking Cache") preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key + kwargs[ + "preset_cache_key" + ] = preset_cache_key # for streaming calls, we need to pass the preset_cache_key cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result != None: - if "detail" in cached_result: - # implies an error occurred + if "detail" in cached_result: + # implies an error occurred pass - else: + else: call_type = original_function.__name__ - print_verbose(f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}") - if call_type == CallTypes.completion.value and isinstance(cached_result, dict): - return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse(), stream = 
kwargs.get("stream", False)) - elif call_type == CallTypes.embedding.value and isinstance(cached_result, dict): - return convert_to_model_response_object(response_object=cached_result, response_type="embedding") - else: + print_verbose( + f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" + ) + if call_type == CallTypes.completion.value and isinstance( + cached_result, dict + ): + return convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + stream=kwargs.get("stream", False), + ) + elif call_type == CallTypes.embedding.value and isinstance( + cached_result, dict + ): + return convert_to_model_response_object( + response_object=cached_result, + response_type="embedding", + ) + else: return cached_result # MODEL CALL result = original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: # TODO: Add to cache for streaming - if "complete_response" in kwargs and kwargs["complete_response"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) - else: + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: return result - elif "acompletion" in kwargs and kwargs["acompletion"] == True: + elif "acompletion" in kwargs and kwargs["acompletion"] == True: return result - elif "aembedding" in kwargs and kwargs["aembedding"] == True: + elif "aembedding" in kwargs and kwargs["aembedding"] == True: return result - - ### POST-CALL RULES ### + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model or None) # [OPTIONAL] ADD TO CACHE - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): litellm.cache.add_cache(result, *args, **kwargs) # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated print_verbose(f"Wrapper: Completed Call, calling success_handler") - threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() + threading.Thread( + target=logging_obj.success_handler, args=(result, start_time, end_time) + ).start() # RETURN RESULT - result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai return result except Exception as e: call_type = original_function.__name__ if call_type == CallTypes.completion.value: num_retries = ( - kwargs.get("num_retries", None) - or litellm.num_retries - or None + kwargs.get("num_retries", None) or litellm.num_retries or None + ) + litellm.num_retries = ( + None # set retries to None to prevent infinite loops + ) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", {} ) - litellm.num_retries = None # set retries to None to prevent infinite loops - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) - if num_retries: - if (isinstance(e, openai.APIError) - or isinstance(e, openai.Timeout)): + if num_retries: + if isinstance(e, openai.APIError) or isinstance(e, 
openai.Timeout): kwargs["num_retries"] = num_retries return litellm.completion_with_retries(*args, **kwargs) - elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: + elif ( + isinstance(e, litellm.exceptions.ContextWindowExceededError) + and context_window_fallback_dict + and model in context_window_fallback_dict + ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return original_function(*args, **kwargs) @@ -1705,7 +2028,9 @@ def client(original_function): end_time = datetime.datetime.now() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated if logging_obj: - logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! + logging_obj.failure_handler( + e, traceback_exception, start_time, end_time + ) # DO NOT MAKE THREADED - router retry fallback relies on this! my_thread = threading.Thread( target=handle_failure, args=(e, traceback_exception, start_time, end_time, args, kwargs), @@ -1717,8 +2042,8 @@ def client(original_function): ): # make it easy to get to the debugger logs if you've initialized it e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e - - async def wrapper_async(*args, **kwargs): + + async def wrapper_async(*args, **kwargs): start_time = datetime.datetime.now() result = None logging_obj = kwargs.get("litellm_logging_obj", None) @@ -1729,115 +2054,215 @@ def client(original_function): model = args[0] if len(args) > 0 else kwargs["model"] except: raise ValueError("model param not passed in.") - - try: + + try: if logging_obj is None: logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj - # [OPTIONAL] CHECK BUDGET + # [OPTIONAL] CHECK BUDGET if litellm.max_budget: if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) + raise BudgetExceededError( + current_cost=litellm._current_cost, + max_budget=litellm.max_budget, + ) # [OPTIONAL] CHECK CACHE print_verbose(f"litellm.cache: {litellm.cache}") - print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}") - # if caching is false, don't run this - if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function + print_verbose( + f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}" + ) + # if caching is false, don't run this + if ( + kwargs.get("caching", None) is None and litellm.cache is not None + ) or kwargs.get( + "caching", False + ) == True: # allow users to control returning cached responses from the completion function # checking cache print_verbose(f"INSIDE CHECKING CACHE") - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): print_verbose(f"Checking Cache") cached_result = litellm.cache.get_cache(*args, **kwargs) if cached_result != None: print_verbose(f"Cache Hit!") call_type = original_function.__name__ - if call_type == 
CallTypes.acompletion.value and isinstance(cached_result, dict): + if call_type == CallTypes.acompletion.value and isinstance( + cached_result, dict + ): if kwargs.get("stream", False) == True: cached_result = convert_to_streaming_response_async( response_object=cached_result, ) else: - cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse()) - elif call_type == CallTypes.aembedding.value and isinstance(cached_result, dict): - cached_result = convert_to_model_response_object(response_object=cached_result, model_response_object=EmbeddingResponse(), response_type="embedding") - # LOG SUCCESS + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=ModelResponse(), + ) + elif call_type == CallTypes.aembedding.value and isinstance( + cached_result, dict + ): + cached_result = convert_to_model_response_object( + response_object=cached_result, + model_response_object=EmbeddingResponse(), + response_type="embedding", + ) + # LOG SUCCESS cache_hit = True end_time = datetime.datetime.now() - model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(model=model, custom_llm_provider=kwargs.get('custom_llm_provider', None), api_base=kwargs.get('api_base', None), api_key=kwargs.get('api_key', None)) - print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") - logging_obj.update_environment_variables(model=model, user=kwargs.get('user', None), optional_params={}, litellm_params={"logger_fn": kwargs.get('logger_fn', None), "acompletion": True, "metadata": kwargs.get("metadata", {}), "model_info": kwargs.get("model_info", {}), "proxy_server_request": kwargs.get("proxy_server_request", None), "preset_cache_key": kwargs.get("preset_cache_key", None), "stream_response": kwargs.get("stream_response", {})}, input=kwargs.get('messages', ""), api_key=kwargs.get('api_key', None), original_response=str(cached_result), additional_args=None, stream=kwargs.get('stream', False)) - asyncio.create_task(logging_obj.async_success_handler(cached_result, start_time, end_time, cache_hit)) - threading.Thread(target=logging_obj.success_handler, args=(cached_result, start_time, end_time, cache_hit)).start() + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = litellm.get_llm_provider( + model=model, + custom_llm_provider=kwargs.get("custom_llm_provider", None), + api_base=kwargs.get("api_base", None), + api_key=kwargs.get("api_key", None), + ) + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + logging_obj.update_environment_variables( + model=model, + user=kwargs.get("user", None), + optional_params={}, + litellm_params={ + "logger_fn": kwargs.get("logger_fn", None), + "acompletion": True, + "metadata": kwargs.get("metadata", {}), + "model_info": kwargs.get("model_info", {}), + "proxy_server_request": kwargs.get( + "proxy_server_request", None + ), + "preset_cache_key": kwargs.get( + "preset_cache_key", None + ), + "stream_response": kwargs.get("stream_response", {}), + }, + input=kwargs.get("messages", ""), + api_key=kwargs.get("api_key", None), + original_response=str(cached_result), + additional_args=None, + stream=kwargs.get("stream", False), + ) + asyncio.create_task( + logging_obj.async_success_handler( + cached_result, start_time, end_time, cache_hit + ) + ) + threading.Thread( + target=logging_obj.success_handler, + 
args=(cached_result, start_time, end_time, cache_hit), + ).start() return cached_result # MODEL CALL result = await original_function(*args, **kwargs) end_time = datetime.datetime.now() if "stream" in kwargs and kwargs["stream"] == True: - if "complete_response" in kwargs and kwargs["complete_response"] == True: + if ( + "complete_response" in kwargs + and kwargs["complete_response"] == True + ): chunks = [] for idx, chunk in enumerate(result): chunks.append(chunk) - return litellm.stream_chunk_builder(chunks, messages=kwargs.get("messages", None)) - else: + return litellm.stream_chunk_builder( + chunks, messages=kwargs.get("messages", None) + ) + else: return result - - ### POST-CALL RULES ### + + ### POST-CALL RULES ### post_call_processing(original_response=result, model=model) # [OPTIONAL] ADD TO CACHE - if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types: - if isinstance(result, litellm.ModelResponse) or isinstance(result, litellm.EmbeddingResponse): - asyncio.create_task(litellm.cache._async_add_cache(result.json(), *args, **kwargs)) + if ( + litellm.cache is not None + and str(original_function.__name__) + in litellm.cache.supported_call_types + ): + if isinstance(result, litellm.ModelResponse) or isinstance( + result, litellm.EmbeddingResponse + ): + asyncio.create_task( + litellm.cache._async_add_cache(result.json(), *args, **kwargs) + ) else: - asyncio.create_task(litellm.cache._async_add_cache(result, *args, **kwargs)) + asyncio.create_task( + litellm.cache._async_add_cache(result, *args, **kwargs) + ) # LOG SUCCESS - handle streaming success logging in the _next_ object - print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}") - asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time)) - threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start() + print_verbose( + f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" + ) + asyncio.create_task( + logging_obj.async_success_handler(result, start_time, end_time) + ) + threading.Thread( + target=logging_obj.success_handler, args=(result, start_time, end_time) + ).start() # RETURN RESULT if isinstance(result, ModelResponse): - result._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai + result._response_ms = ( + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai return result - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() crash_reporting(*args, **kwargs, exception=traceback_exception) end_time = datetime.datetime.now() if logging_obj: try: - logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this! - except Exception as e: - raise e - try: - await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time) + logging_obj.failure_handler( + e, traceback_exception, start_time, end_time + ) # DO NOT MAKE THREADED - router retry fallback relies on this! 
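# --- editor's note: illustrative sketch, not part of this patch --------------
# The failure path above fans a registered callback out to the synchronous
# failure_handler and the awaited async_failure_handler; any CustomLogger placed
# in `litellm.callbacks` is routed by function_setup into the async failure list
# and gets `async_log_failure_event(...)` called with the same kwargs /
# response_obj / start_time / end_time arguments used here. A minimal subscriber
# could look like the sketch below; the import path and method signature mirror
# what this diff invokes, but treat it as an assumed example rather than the
# canonical API.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class StdoutFailureLogger(CustomLogger):
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # kwargs carries model_call_details (model, messages, litellm_params, ...)
        elapsed_s = (end_time - start_time).total_seconds()
        print(f"LLM call failed after {elapsed_s:.2f}s for model={kwargs.get('model')}")


litellm.callbacks = [StdoutFailureLogger()]  # picked up by function_setup on the next call
# ------------------------------------------------------------------------------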
except Exception as e: raise e - + try: + await logging_obj.async_failure_handler( + e, traceback_exception, start_time, end_time + ) + except Exception as e: + raise e + call_type = original_function.__name__ if call_type == CallTypes.acompletion.value: num_retries = ( - kwargs.get("num_retries", None) - or litellm.num_retries - or None + kwargs.get("num_retries", None) or litellm.num_retries or None ) - litellm.num_retries = None # set retries to None to prevent infinite loops - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {}) - - if num_retries: - try: + litellm.num_retries = ( + None # set retries to None to prevent infinite loops + ) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", {} + ) + + if num_retries: + try: kwargs["num_retries"] = num_retries kwargs["original_function"] = original_function - if (isinstance(e, openai.RateLimitError)): # rate limiting specific error + if isinstance( + e, openai.RateLimitError + ): # rate limiting specific error kwargs["retry_strategy"] = "exponential_backoff_retry" - elif (isinstance(e, openai.APIError)): # generic api error + elif isinstance(e, openai.APIError): # generic api error kwargs["retry_strategy"] = "constant_retry" return await litellm.acompletion_with_retries(*args, **kwargs) except: pass - elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict: + elif ( + isinstance(e, litellm.exceptions.ContextWindowExceededError) + and context_window_fallback_dict + and model in context_window_fallback_dict + ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] else: kwargs["model"] = context_window_fallback_dict[model] return await original_function(*args, **kwargs) @@ -1851,6 +2276,7 @@ def client(original_function): else: return wrapper + ####### USAGE CALCULATOR ################ @@ -1858,7 +2284,10 @@ def client(original_function): # only used for together_computer LLMs def get_model_params_and_category(model_name): import re - params_match = re.search(r'(\d+b)', model_name) # catch all decimals like 3b, 70b, etc + + params_match = re.search( + r"(\d+b)", model_name + ) # catch all decimals like 3b, 70b, etc category = None if params_match != None: params_match = params_match.group(1) @@ -1879,30 +2308,36 @@ def get_model_params_and_category(model_name): return None + def get_replicate_completion_pricing(completion_response=None, total_time=0.0): # see https://replicate.com/pricing a100_40gb_price_per_second_public = 0.001150 # for all litellm currently supported LLMs, almost all requests go to a100_80gb - a100_80gb_price_per_second_public = 0.001400 # assume all calls sent to A100 80GB for now + a100_80gb_price_per_second_public = ( + 0.001400 # assume all calls sent to A100 80GB for now + ) if total_time == 0.0: - start_time = completion_response['created'] + start_time = completion_response["created"] end_time = completion_response["ended"] total_time = end_time - start_time - return a100_80gb_price_per_second_public*total_time + return a100_80gb_price_per_second_public * total_time -def _select_tokenizer(model: str): - # cohere +def _select_tokenizer(model: str): + # cohere import pkg_resources + if model in litellm.cohere_models: tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # anthropic + # anthropic elif model in litellm.anthropic_models: # Read the 
JSON file - filename = pkg_resources.resource_filename(__name__, 'llms/tokenizers/anthropic_tokenizer.json') - with open(filename, 'r') as f: + filename = pkg_resources.resource_filename( + __name__, "llms/tokenizers/anthropic_tokenizer.json" + ) + with open(filename, "r") as f: json_data = json.load(f) # Decode the JSON data from utf-8 json_data_decoded = json.dumps(json_data, ensure_ascii=False) @@ -1911,15 +2346,16 @@ def _select_tokenizer(model: str): # load tokenizer tokenizer = Tokenizer.from_str(json_str) return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} - # llama2 - elif "llama-2" in model.lower(): + # llama2 + elif "llama-2" in model.lower(): tokenizer = Tokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") return {"type": "huggingface_tokenizer", "tokenizer": tokenizer} # default - tiktoken - else: + else: return {"type": "openai_tokenizer", "tokenizer": encoding} -def encode(model: str, text: str): + +def encode(model: str, text: str): """ Encodes the given text using the specified model. @@ -1934,12 +2370,18 @@ def encode(model: str, text: str): enc = tokenizer_json["tokenizer"].encode(text) return enc -def decode(model: str, tokens: List[int]): + +def decode(model: str, tokens: List[int]): tokenizer_json = _select_tokenizer(model=model) dec = tokenizer_json["tokenizer"].decode(tokens) return dec -def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-0613", text: Optional[str]= None): + +def openai_token_counter( + messages: Optional[list] = None, + model="gpt-3.5-turbo-0613", + text: Optional[str] = None, +): """ Return the number of tokens used by a list of messages. @@ -1951,7 +2393,9 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 print_verbose("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_message = ( + 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + ) tokens_per_name = -1 # if there's a name, the role is omitted elif model in litellm.open_ai_chat_completion_models: tokens_per_message = 3 @@ -1962,9 +2406,9 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 ) num_tokens = 0 - if text: + if text: num_tokens = len(encoding.encode(text, disallowed_special=())) - elif messages: + elif messages: for message in messages: num_tokens += tokens_per_message for key, value in message.items(): @@ -1974,7 +2418,8 @@ def openai_token_counter(messages: Optional[list]=None, model="gpt-3.5-turbo-061 num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens -def token_counter(model="", text=None, messages: Optional[List] = None): + +def token_counter(model="", text=None, messages: Optional[List] = None): """ Count the number of tokens in a given text using a specified model. 
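# --- editor's note: illustrative sketch, not part of this patch --------------
# _select_tokenizer above routes cohere / anthropic / llama-2 models to a
# Hugging Face tokenizer and falls back to tiktoken for everything else;
# encode, decode and token_counter build on that routing. A hedged usage
# example follows: token_counter is a documented litellm export, while calling
# encode/decode at the top level is assumed to work the same way.
import litellm

messages = [{"role": "user", "content": "This is a test"}]

# Chat-style counting: applies the per-message overhead for OpenAI chat models,
# otherwise falls back to the tokenizer picked by _select_tokenizer.
n_tokens = litellm.token_counter(model="gpt-3.5-turbo", messages=messages)

# Round-trip a raw string through the model's tokenizer (tiktoken here).
token_ids = litellm.encode(model="gpt-3.5-turbo", text="This is a test")
round_trip = litellm.decode(model="gpt-3.5-turbo", tokens=token_ids)
print(n_tokens, len(token_ids), round_trip)
# ------------------------------------------------------------------------------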
@@ -1990,24 +2435,24 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
     if text == None:
         if messages is not None:
             print_verbose(f"token_counter messages received: {messages}")
-            text = ""
-            for message in messages:
+            text = ""
+            for message in messages:
                 if message.get("content", None):
                     text += message["content"]
-                if 'tool_calls' in message:
-                    for tool_call in message['tool_calls']:
-                        if 'function' in tool_call:
-                            function_arguments = tool_call['function']['arguments']
+                if "tool_calls" in message:
+                    for tool_call in message["tool_calls"]:
+                        if "function" in tool_call:
+                            function_arguments = tool_call["function"]["arguments"]
                             text += function_arguments
         else:
             raise ValueError("text and messages cannot both be None")
     num_tokens = 0
     if model is not None:
         tokenizer_json = _select_tokenizer(model=model)
-        if tokenizer_json["type"] == "huggingface_tokenizer":
+        if tokenizer_json["type"] == "huggingface_tokenizer":
             enc = tokenizer_json["tokenizer"].encode(text)
             num_tokens = len(enc.ids)
-        elif tokenizer_json["type"] == "openai_tokenizer":
+        elif tokenizer_json["type"] == "openai_tokenizer":
             if model in litellm.open_ai_chat_completion_models:
                 num_tokens = openai_token_counter(text=text, model=model)
             else:
@@ -2026,7 +2471,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
         model (str): The name of the model to use. Default is ""
         prompt_tokens (int): The number of tokens in the prompt.
         completion_tokens (int): The number of tokens in the completion.
-
+
     Returns:
         tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
     """
@@ -2038,7 +2483,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
     azure_llms = {
         "gpt-35-turbo": "azure/gpt-3.5-turbo",
         "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k",
-        "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct"
+        "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct",
     }
     if model in model_cost_ref:
         prompt_tokens_cost_usd_dollar = (
@@ -2054,7 +2499,8 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
             model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
         )
         completion_tokens_cost_usd_dollar = (
-            model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"] * completion_tokens
+            model_cost_ref["ft:gpt-3.5-turbo"]["output_cost_per_token"]
+            * completion_tokens
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model in azure_llms:
@@ -2082,22 +2528,22 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
 
 
 def completion_cost(
-        completion_response=None,
-        model=None,
-        prompt="",
-        messages: List = [],
-        completion="",
-        total_time=0.0, # used for replicate
-    ):
+    completion_response=None,
+    model=None,
+    prompt="",
+    messages: List = [],
+    completion="",
+    total_time=0.0,  # used for replicate
+):
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
 
     Parameters:
         completion_response (litellm.ModelResponses): [Required] The response received from a LiteLLM completion request.
-
+
     [OPTIONAL PARAMS]
     model (str): Optional. The name of the language model used in the completion calls
-    prompt (str): Optional. The input prompt passed to the llm
+    prompt (str): Optional. The input prompt passed to the llm
     completion (str): Optional. The output completion text from the llm
     total_time (float): Optional.
(Only used for Replicate LLMs) The total time used for the request in seconds @@ -2120,41 +2566,46 @@ def completion_cost( completion_tokens = 0 if completion_response is not None: # get input/output tokens from completion_response - prompt_tokens = completion_response['usage']['prompt_tokens'] - completion_tokens = completion_response['usage']['completion_tokens'] - model = model or completion_response['model'] # check if user passed an override for model, if it's none check completion_response['model'] + prompt_tokens = completion_response["usage"]["prompt_tokens"] + completion_tokens = completion_response["usage"]["completion_tokens"] + model = ( + model or completion_response["model"] + ) # check if user passed an override for model, if it's none check completion_response['model'] else: if len(messages) > 0: prompt_tokens = token_counter(model=model, messages=messages) - elif len(prompt) > 0: + elif len(prompt) > 0: prompt_tokens = token_counter(model=model, text=prompt) completion_tokens = token_counter(model=model, text=completion) - + # Calculate cost based on prompt_tokens, completion_tokens if "togethercomputer" in model: # together ai prices based on size of llm - # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json + # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json model = get_model_params_and_category(model) # replicate llms are calculate based on time for request running # see https://replicate.com/pricing - elif ( - model in litellm.replicate_models or - "replicate" in model - ): + elif model in litellm.replicate_models or "replicate" in model: return get_replicate_completion_pricing(completion_response, total_time) - prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token( - model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ( + prompt_tokens_cost_usd_dollar, + completion_tokens_cost_usd_dollar, + ) = cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar except: - return 0.0 # this should not block a users execution path + return 0.0 # this should not block a users execution path + ####### HELPER FUNCTIONS ################ -def register_model(model_cost: Union[str, dict]): +def register_model(model_cost: Union[str, dict]): """ - Register new / Override existing models (and their pricing) to specific providers. + Register new / Override existing models (and their pricing) to specific providers. 
Provide EITHER a model cost dictionary or a url to a hosted json blob - Example usage: + Example usage: model_cost_dict = { "gpt-4": { "max_tokens": 8192, @@ -2166,59 +2617,60 @@ def register_model(model_cost: Union[str, dict]): } """ loaded_model_cost = {} - if isinstance(model_cost, dict): + if isinstance(model_cost, dict): loaded_model_cost = model_cost - elif isinstance(model_cost, str): + elif isinstance(model_cost, str): loaded_model_cost = litellm.get_model_cost_map(url=model_cost) for key, value in loaded_model_cost.items(): ## override / add new keys to the existing model cost dictionary if key in litellm.model_cost: - for k,v in loaded_model_cost[key].items(): + for k, v in loaded_model_cost[key].items(): litellm.model_cost[key][k] = v # add new model names to provider lists - if value.get('litellm_provider') == 'openai': + if value.get("litellm_provider") == "openai": if key not in litellm.open_ai_chat_completion_models: litellm.open_ai_chat_completion_models.append(key) - elif value.get('litellm_provider') == 'text-completion-openai': + elif value.get("litellm_provider") == "text-completion-openai": if key not in litellm.open_ai_text_completion_models: litellm.open_ai_text_completion_models.append(key) - elif value.get('litellm_provider') == 'cohere': + elif value.get("litellm_provider") == "cohere": if key not in litellm.cohere_models: litellm.cohere_models.append(key) - elif value.get('litellm_provider') == 'anthropic': + elif value.get("litellm_provider") == "anthropic": if key not in litellm.anthropic_models: litellm.anthropic_models.append(key) - elif value.get('litellm_provider') == 'openrouter': - split_string = key.split('/', 1) + elif value.get("litellm_provider") == "openrouter": + split_string = key.split("/", 1) if key not in litellm.openrouter_models: litellm.openrouter_models.append(split_string[1]) - elif value.get('litellm_provider') == 'vertex_ai-text-models': + elif value.get("litellm_provider") == "vertex_ai-text-models": if key not in litellm.vertex_text_models: litellm.vertex_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-text-models': + elif value.get("litellm_provider") == "vertex_ai-code-text-models": if key not in litellm.vertex_code_text_models: litellm.vertex_code_text_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-chat-models': + elif value.get("litellm_provider") == "vertex_ai-chat-models": if key not in litellm.vertex_chat_models: litellm.vertex_chat_models.append(key) - elif value.get('litellm_provider') == 'vertex_ai-code-chat-models': + elif value.get("litellm_provider") == "vertex_ai-code-chat-models": if key not in litellm.vertex_code_chat_models: litellm.vertex_code_chat_models.append(key) - elif value.get('litellm_provider') == 'ai21': + elif value.get("litellm_provider") == "ai21": if key not in litellm.ai21_models: litellm.ai21_models.append(key) - elif value.get('litellm_provider') == 'nlp_cloud': + elif value.get("litellm_provider") == "nlp_cloud": if key not in litellm.nlp_cloud_models: litellm.nlp_cloud_models.append(key) - elif value.get('litellm_provider') == 'aleph_alpha': + elif value.get("litellm_provider") == "aleph_alpha": if key not in litellm.aleph_alpha_models: litellm.aleph_alpha_models.append(key) - elif value.get('litellm_provider') == 'bedrock': + elif value.get("litellm_provider") == "bedrock": if key not in litellm.bedrock_models: litellm.bedrock_models.append(key) return model_cost + def get_litellm_params( api_key=None, force_timeout=600, @@ -2237,7 +2689,7 @@ 
def get_litellm_params( model_info=None, proxy_server_request=None, acompletion=None, - preset_cache_key = None + preset_cache_key=None, ): litellm_params = { "acompletion": acompletion, @@ -2254,20 +2706,21 @@ def get_litellm_params( "model_info": model_info, "proxy_server_request": proxy_server_request, "preset_cache_key": preset_cache_key, - "stream_response": {} # litellm_call_id: ModelResponse Dict + "stream_response": {}, # litellm_call_id: ModelResponse Dict } return litellm_params + def get_optional_params_image_gen( - n: Optional[int]=None, - quality: Optional[str]=None, - response_format: Optional[str]=None, - size: Optional[str]=None, - style: Optional[str]=None, - user: Optional[str]=None, - custom_llm_provider: Optional[str]=None, - **kwargs + n: Optional[int] = None, + quality: Optional[str] = None, + response_format: Optional[str] = None, + size: Optional[str] = None, + style: Optional[str] = None, + user: Optional[str] = None, + custom_llm_provider: Optional[str] = None, + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2275,38 +2728,44 @@ def get_optional_params_image_gen( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - + default_params = { - "n": None, - "quality" : None, - "response_format" : None, - "size": None, + "n": None, + "quality": None, + "response_format": None, + "size": None, "style": None, "user": None, } - non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} + non_default_params = { + k: v + for k, v in passed_params.items() + if (k in default_params and v != default_params[k]) + } ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"Setting user/encoding format is not supported by {custom_llm_provider}. 
To drop it from the call, set `litellm.drop_params = True`.", + ) + final_params = {**non_default_params, **kwargs} return final_params - - + def get_optional_params_embeddings( # 2 optional params - user=None, + user=None, encoding_format=None, custom_llm_provider="", - **kwargs + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2314,26 +2773,31 @@ def get_optional_params_embeddings( special_params = passed_params.pop("kwargs") for k, v in special_params.items(): passed_params[k] = v - - default_params = { - "user": None, - "encoding_format": None - } - non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])} + default_params = {"user": None, "encoding_format": None} + + non_default_params = { + k: v + for k, v in passed_params.items() + if (k in default_params and v != default_params[k]) + } ## raise exception if non-default value passed for non-openai/azure embedding calls if custom_llm_provider != "openai" and custom_llm_provider != "azure": - if len(non_default_params.keys()) > 0: - if litellm.drop_params is True: # drop the unsupported non-default values + if len(non_default_params.keys()) > 0: + if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) - for k in keys: + for k in keys: non_default_params.pop(k, None) return non_default_params - raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", + ) + final_params = {**non_default_params, **kwargs} return final_params + def get_optional_params( # use the openai defaults # 12 optional params functions=None, @@ -2355,7 +2819,7 @@ def get_optional_params( # use the openai defaults tools=None, tool_choice=None, max_retries=None, - **kwargs + **kwargs, ): # retrieve all parameters passed to the function passed_params = locals() @@ -2365,18 +2829,18 @@ def get_optional_params( # use the openai defaults default_params = { "functions": None, "function_call": None, - "temperature":None, - "top_p":None, - "n":None, - "stream":None, - "stop":None, - "max_tokens":None, - "presence_penalty":None, - "frequency_penalty":None, + "temperature": None, + "top_p": None, + "n": None, + "stream": None, + "stop": None, + "max_tokens": None, + "presence_penalty": None, + "frequency_penalty": None, "logit_bias": None, - "user":None, - "model":None, - "custom_llm_provider":"", + "user": None, + "model": None, + "custom_llm_provider": "", "response_format": None, "seed": None, "tools": None, @@ -2384,53 +2848,82 @@ def get_optional_params( # use the openai defaults "max_retries": None, } # filter out those parameters that were passed with non-default values - non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])} + non_default_params = { + k: v + for k, v in passed_params.items() + if ( + k != "model" + and k != "custom_llm_provider" + and k in default_params + and v != default_params[k] + ) + } optional_params = {} ## raise exception if function calling passed in for a provider that doesn't support it if "functions" in non_default_params or "function_call" in 
non_default_params: - if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure": - if litellm.add_function_to_prompt: # if user opts to add it to prompt instead - optional_params["functions_unsupported_model"] = non_default_params.pop("functions") - else: - raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.") + if ( + custom_llm_provider != "openai" + and custom_llm_provider != "text-completion-openai" + and custom_llm_provider != "azure" + ): + if ( + litellm.add_function_to_prompt + ): # if user opts to add it to prompt instead + optional_params["functions_unsupported_model"] = non_default_params.pop( + "functions" + ) + else: + raise UnsupportedParamsError( + status_code=500, + message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.", + ) - def _check_valid_arg(supported_params): - print_verbose(f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}") + def _check_valid_arg(supported_params): + print_verbose( + f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}" + ) print_verbose(f"\nLiteLLM: Params passed to completion() {passed_params}") - print_verbose(f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}") + print_verbose( + f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}" + ) unsupported_params = {} for k in non_default_params.keys(): if k not in supported_params: - if k == "n" and n == 1: # langchain sends n=1 as a default value - continue # skip this param - if k == "max_retries": # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries - continue # skip this param + if k == "n" and n == 1: # langchain sends n=1 as a default value + continue # skip this param + if ( + k == "max_retries" + ): # TODO: This is a patch. We support max retries for OpenAI, Azure. For non OpenAI LLMs we need to add support for max retries + continue # skip this param # Always keeps this in elif code blocks - else: + else: unsupported_params[k] = non_default_params[k] if unsupported_params and not litellm.drop_params: - raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.") - + raise UnsupportedParamsError( + status_code=500, + message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.", + ) + def _map_and_modify_arg(supported_params: dict, provider: str, model: str): """ filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`. 
""" filtered_stop = None - if "stop" in supported_params and litellm.drop_params: - if provider == "bedrock" and "amazon" in model: + if "stop" in supported_params and litellm.drop_params: + if provider == "bedrock" and "amazon" in model: filtered_stop = [] - if isinstance(stop, list): - for s in stop: - if re.match(r'^(\|+|User:)$', s): - filtered_stop.append(s) - if filtered_stop is not None: + if isinstance(stop, list): + for s in stop: + if re.match(r"^(\|+|User:)$", s): + filtered_stop.append(s) + if filtered_stop is not None: supported_params["stop"] = filtered_stop return supported_params - ## raise exception if provider doesn't support passed in param + ## raise exception if provider doesn't support passed in param if custom_llm_provider == "anthropic": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle anthropic params @@ -2438,7 +2931,7 @@ def get_optional_params( # use the openai defaults optional_params["stream"] = stream if stop is not None: if type(stop) == str: - stop = [stop] # openai can accept str/list for stop + stop = [stop] # openai can accept str/list for stop optional_params["stop_sequences"] = stop if temperature is not None: optional_params["temperature"] = temperature @@ -2447,8 +2940,18 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens_to_sample"] = max_tokens elif custom_llm_provider == "cohere": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop", "n"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "logit_bias", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + ] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2470,8 +2973,15 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "maritalk": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "top_p", + "presence_penalty", + "stop", + ] _check_valid_arg(supported_params=supported_params) # handle cohere params if stream: @@ -2489,17 +2999,24 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stopping_tokens"] = stop elif custom_llm_provider == "replicate": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "top_p", + "stop", + "seed", + ] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream"] = stream return optional_params if max_tokens is not None: if "vicuna" in model or "flan" in model: optional_params["max_length"] = max_tokens - elif "meta/codellama-13b" in model: + elif "meta/codellama-13b" in model: optional_params["max_tokens"] = max_tokens else: optional_params["max_new_tokens"] = max_tokens @@ -2510,7 +3027,7 @@ def get_optional_params( # use the openai defaults if stop is not 
None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "huggingface": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None @@ -2524,7 +3041,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -2535,7 +3054,7 @@ def get_optional_params( # use the openai defaults if max_tokens == 0: max_tokens = 1 optional_params["max_new_tokens"] = max_tokens - if n is not None: + if n is not None: optional_params["best_of"] = n if presence_penalty is not None: optional_params["repetition_penalty"] = presence_penalty @@ -2543,12 +3062,21 @@ def get_optional_params( # use the openai defaults # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False optional_params["decoder_input_details"] = special_params["echo"] - passed_params.pop("echo", None) # since we handle translating echo, we should not send it to TGI request + passed_params.pop( + "echo", None + ) # since we handle translating echo, we should not send it to TGI request elif custom_llm_provider == "together_ai": - ## check if unsupported param passed in - supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "temperature", + "max_tokens", + "top_p", + "stop", + "frequency_penalty", + ] _check_valid_arg(supported_params=supported_params) - + if stream: optional_params["stream_tokens"] = stream if temperature is not None: @@ -2558,12 +3086,23 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens if frequency_penalty is not None: - optional_params["repetition_penalty"] = frequency_penalty # https://docs.together.ai/reference/inference + optional_params[ + "repetition_penalty" + ] = frequency_penalty # https://docs.together.ai/reference/inference if stop is not None: - optional_params["stop"] = stop + optional_params["stop"] = stop elif custom_llm_provider == "ai21": - ## check if unsupported param passed in - supported_params = ["stream", "n", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty", "presence_penalty"] + ## check if unsupported param passed in + supported_params = [ + "stream", + "n", + "temperature", + "max_tokens", + "top_p", + "stop", + "frequency_penalty", + "presence_penalty", + ] _check_valid_arg(supported_params=supported_params) if stream: @@ -2582,11 +3121,13 @@ def get_optional_params( # use the openai defaults optional_params["frequencyPenalty"] = {"scale": frequency_penalty} if presence_penalty is not None: optional_params["presencePenalty"] = {"scale": presence_penalty} - elif custom_llm_provider == "palm": # 
https://developers.generativeai.google/tutorials/curl_quickstart - ## check if unsupported param passed in + elif ( + custom_llm_provider == "palm" + ): # https://developers.generativeai.google/tutorials/curl_quickstart + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -2599,13 +3140,11 @@ def get_optional_params( # use the openai defaults optional_params["stop_sequences"] = stop if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens - elif ( - custom_llm_provider == "vertex_ai" - ): - ## check if unsupported param passed in + elif custom_llm_provider == "vertex_ai": + ## check if unsupported param passed in supported_params = ["temperature", "top_p", "max_tokens", "stream"] _check_valid_arg(supported_params=supported_params) - + if temperature is not None: optional_params["temperature"] = temperature if top_p is not None: @@ -2615,7 +3154,7 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_output_tokens"] = max_tokens elif custom_llm_provider == "sagemaker": - ## check if unsupported param passed in + ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None @@ -2629,7 +3168,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if n is not None: optional_params["best_of"] = n - optional_params["do_sample"] = True # Need to sample if you want best of for hf inference endpoints + optional_params[ + "do_sample" + ] = True # Need to sample if you want best of for hf inference endpoints if stream is not None: optional_params["stream"] = stream if stop is not None: @@ -2652,7 +3193,7 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream elif "anthropic" in model: supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] @@ -2667,9 +3208,9 @@ def get_optional_params( # use the openai defaults optional_params["top_p"] = top_p if stop is not None: optional_params["stop_sequences"] = stop - if stream: + if stream: optional_params["stream"] = stream - elif "amazon" in model: # amazon titan llms + elif "amazon" in model: # amazon titan llms supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -2678,13 +3219,15 @@ def get_optional_params( # use the openai defaults if temperature is not None: optional_params["temperature"] = temperature if stop is not None: - filtered_stop = _map_and_modify_arg({"stop": stop}, provider="bedrock", model=model) + filtered_stop = _map_and_modify_arg( + {"stop": stop}, provider="bedrock", model=model + ) optional_params["stopSequences"] = filtered_stop["stop"] if top_p is not None: optional_params["topP"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "meta" in model: # amazon / meta llms + elif "meta" in model: # amazon / meta llms supported_params = 
["max_tokens", "temperature", "top_p", "stream"] _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large @@ -2694,9 +3237,9 @@ def get_optional_params( # use the openai defaults optional_params["temperature"] = temperature if top_p is not None: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - elif "cohere" in model: # cohere models on bedrock + elif "cohere" in model: # cohere models on bedrock supported_params = ["stream", "temperature", "max_tokens"] _check_valid_arg(supported_params=supported_params) # handle cohere params @@ -2707,7 +3250,16 @@ def get_optional_params( # use the openai defaults if max_tokens is not None: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "aleph_alpha": - supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"] + supported_params = [ + "max_tokens", + "stream", + "top_p", + "temperature", + "presence_penalty", + "frequency_penalty", + "n", + "stop", + ] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: optional_params["maximum_tokens"] = max_tokens @@ -2726,9 +3278,16 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "ollama": - supported_params = ["max_tokens", "stream", "top_p", "temperature", "frequency_penalty", "stop"] + supported_params = [ + "max_tokens", + "stream", + "top_p", + "temperature", + "frequency_penalty", + "stop", + ] _check_valid_arg(supported_params=supported_params) - + if max_tokens is not None: optional_params["num_predict"] = max_tokens if stream: @@ -2742,7 +3301,16 @@ def get_optional_params( # use the openai defaults if stop is not None: optional_params["stop_sequences"] = stop elif custom_llm_provider == "nlp_cloud": - supported_params = ["max_tokens", "stream", "temperature", "top_p", "presence_penalty", "frequency_penalty", "n", "stop"] + supported_params = [ + "max_tokens", + "stream", + "temperature", + "top_p", + "presence_penalty", + "frequency_penalty", + "n", + "stop", + ] _check_valid_arg(supported_params=supported_params) if max_tokens is not None: @@ -2774,60 +3342,84 @@ def get_optional_params( # use the openai defaults if stream: optional_params["stream"] = stream elif custom_llm_provider == "deepinfra": - supported_params = ["temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"] + supported_params = [ + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + ] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature if top_p: optional_params["top_p"] = top_p - if n: + if n: optional_params["n"] = n - if stream: + if stream: optional_params["stream"] = stream - if stop: + if stop: optional_params["stop"] = stop - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if 
presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty - if logit_bias: + if logit_bias: optional_params["logit_bias"] = logit_bias - if user: + if user: optional_params["user"] = user elif custom_llm_provider == "perplexity": - supported_params = ["temperature", "top_p", "stream", "max_tokens", "presence_penalty", "frequency_penalty"] + supported_params = [ + "temperature", + "top_p", + "stream", + "max_tokens", + "presence_penalty", + "frequency_penalty", + ] _check_valid_arg(supported_params=supported_params) if temperature is not None: - if temperature == 0 and model == "mistral-7b-instruct": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistral-7b-instruct" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens - if presence_penalty: + if presence_penalty: optional_params["presence_penalty"] = presence_penalty - if frequency_penalty: + if frequency_penalty: optional_params["frequency_penalty"] = frequency_penalty elif custom_llm_provider == "anyscale": supported_params = ["temperature", "top_p", "stream", "max_tokens"] _check_valid_arg(supported_params=supported_params) optional_params = non_default_params if temperature is not None: - if temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1": # this model does no support temperature == 0 - temperature = 0.0001 # close to 0 + if ( + temperature == 0 and model == "mistralai/Mistral-7B-Instruct-v0.1" + ): # this model does no support temperature == 0 + temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature - if top_p: + if top_p: optional_params["top_p"] = top_p - if stream: + if stream: optional_params["stream"] = stream - if max_tokens: + if max_tokens: optional_params["max_tokens"] = max_tokens elif custom_llm_provider == "mistral": supported_params = ["temperature", "top_p", "stream", "max_tokens"] @@ -2835,13 +3427,13 @@ def get_optional_params( # use the openai defaults optional_params = non_default_params if temperature is not None: optional_params["temperature"] = temperature - if top_p is not None: + if top_p is not None: optional_params["top_p"] = top_p - if stream is not None: + if stream is not None: optional_params["stream"] = stream - if max_tokens is not None: + if max_tokens is not None: optional_params["max_tokens"] = max_tokens - + # check safe_mode, random_seed: https://docs.mistral.ai/api/#operation/createChatCompletion safe_mode = passed_params.pop("safe_mode", None) random_seed = passed_params.pop("random_seed", None) @@ -2850,9 +3442,29 @@ def get_optional_params( # use the openai defaults extra_body["safe_mode"] = safe_mode if random_seed is not None: extra_body["random_seed"] = random_seed - optional_params["extra_body"] = extra_body # openai client supports `extra_body` param + optional_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param elif custom_llm_provider == "openrouter": - supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", 
"tools", "tool_choice", "max_retries"] + supported_params = [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -2889,7 +3501,7 @@ def get_optional_params( # use the openai defaults optional_params["tool_choice"] = tool_choice if max_retries is not None: optional_params["max_retries"] = max_retries - + # OpenRouter-only parameters extra_body = {} transforms = passed_params.pop("transforms", None) @@ -2901,9 +3513,29 @@ def get_optional_params( # use the openai defaults extra_body["models"] = models if route is not None: extra_body["route"] = route - optional_params["extra_body"] = extra_body # openai client supports `extra_body` param + optional_params[ + "extra_body" + ] = extra_body # openai client supports `extra_body` param else: # assume passing in params for openai/azure openai - supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"] + supported_params = [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + ] _check_valid_arg(supported_params=supported_params) if functions is not None: optional_params["functions"] = functions @@ -2940,35 +3572,44 @@ def get_optional_params( # use the openai defaults if max_retries is not None: optional_params["max_retries"] = max_retries optional_params = non_default_params - # if user passed in non-default kwargs for specific providers/models, pass them along - for k in passed_params.keys(): - if k not in default_params.keys(): + # if user passed in non-default kwargs for specific providers/models, pass them along + for k in passed_params.keys(): + if k not in default_params.keys(): optional_params[k] = passed_params[k] return optional_params -def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_base: Optional[str] = None, api_key: Optional[str] = None): + +def get_llm_provider( + model: str, + custom_llm_provider: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, +): try: dynamic_api_key = None # check if llm provider provided - + if custom_llm_provider: return model, custom_llm_provider, dynamic_api_key, api_base - - if api_key and api_key.startswith("os.environ/"): + + if api_key and api_key.startswith("os.environ/"): dynamic_api_key = get_secret(api_key) # check if llm provider part of model name - if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list: + if ( + model.split("/", 1)[0] in litellm.provider_list + and model.split("/", 1)[0] not in litellm.model_list + ): custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] if custom_llm_provider == "perplexity": # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai api_base = "https://api.perplexity.ai" dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") - elif custom_llm_provider == "anyscale": + elif custom_llm_provider == 
"anyscale": # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.endpoints.anyscale.com/v1" dynamic_api_key = get_secret("ANYSCALE_API_KEY") - elif custom_llm_provider == "deepinfra": + elif custom_llm_provider == "deepinfra": # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.deepinfra.com/v1/openai" dynamic_api_key = get_secret("DEEPINFRA_API_KEY") @@ -2979,7 +3620,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint - if api_base: + if api_base: for endpoint in litellm.openai_compatible_endpoints: if endpoint in api_base: if endpoint == "api.perplexity.ai": @@ -2998,20 +3639,26 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) ## openai - chatcompletion + text completion - if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models: + if ( + model in litellm.open_ai_chat_completion_models + or "ft:gpt-3.5-turbo" in model + or model in litellm.openai_image_generation_models + ): custom_llm_provider = "openai" elif model in litellm.open_ai_text_completion_models: custom_llm_provider = "text-completion-openai" - ## anthropic + ## anthropic elif model in litellm.anthropic_models: custom_llm_provider = "anthropic" ## cohere elif model in litellm.cohere_models or model in litellm.cohere_embedding_models: custom_llm_provider = "cohere" ## replicate - elif model in litellm.replicate_models or (":" in model and len(model)>64): + elif model in litellm.replicate_models or (":" in model and len(model) > 64): model_parts = model.split(":") - if len(model_parts) > 1 and len(model_parts[1])==64: ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + if ( + len(model_parts) > 1 and len(model_parts[1]) == 64 + ): ## checks if model name has a 64 digit code - e.g. 
"meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" custom_llm_provider = "replicate" elif model in litellm.replicate_models: custom_llm_provider = "replicate" @@ -3021,22 +3668,22 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ ## openrouter elif model in litellm.maritalk_models: custom_llm_provider = "maritalk" - ## vertex - text + chat + language (gemini) models - elif( - model in litellm.vertex_chat_models or - model in litellm.vertex_code_chat_models or - model in litellm.vertex_text_models or - model in litellm.vertex_code_text_models or - model in litellm.vertex_language_models + ## vertex - text + chat + language (gemini) models + elif ( + model in litellm.vertex_chat_models + or model in litellm.vertex_code_chat_models + or model in litellm.vertex_text_models + or model in litellm.vertex_code_text_models + or model in litellm.vertex_language_models ): custom_llm_provider = "vertex_ai" - ## ai21 + ## ai21 elif model in litellm.ai21_models: custom_llm_provider = "ai21" - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: custom_llm_provider = "aleph_alpha" - ## baseten + ## baseten elif model in litellm.baseten_models: custom_llm_provider = "baseten" ## nlp_cloud @@ -3046,107 +3693,80 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ elif model in litellm.petals_models: custom_llm_provider = "petals" ## bedrock - elif model in litellm.bedrock_models or model in litellm.bedrock_embedding_models: + elif ( + model in litellm.bedrock_models or model in litellm.bedrock_embedding_models + ): custom_llm_provider = "bedrock" # openai embeddings elif model in litellm.open_ai_embedding_models: custom_llm_provider = "openai" - if custom_llm_provider is None or custom_llm_provider=="": - print() # noqa - print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m") # noqa - print() # noqa + if custom_llm_provider is None or custom_llm_provider == "": + print() # noqa + print( + "\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" + ) # noqa + print() # noqa error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. 
For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers" # maps to openai.NotFoundError, this is raised when openai does not recognize the llm - raise litellm.exceptions.NotFoundError( # type: ignore + raise litellm.exceptions.NotFoundError( # type: ignore message=error_str, model=model, response=httpx.Response( status_code=404, - content=error_str, - request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm") # type: ignore + content=error_str, + request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore ), - llm_provider="" + llm_provider="", ) return model, custom_llm_provider, dynamic_api_key, api_base - except Exception as e: + except Exception as e: raise e def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): - api_key = (dynamic_api_key or litellm.api_key) - # openai + api_key = dynamic_api_key or litellm.api_key + # openai if llm_provider == "openai" or llm_provider == "text-completion-openai": - api_key = ( - api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") - ) - # anthropic + api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + # anthropic elif llm_provider == "anthropic": - api_key = ( - api_key or - litellm.anthropic_key or - get_secret("ANTHROPIC_API_KEY") - ) - # ai21 + api_key = api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY") + # ai21 elif llm_provider == "ai21": - api_key = ( - api_key or - litellm.ai21_key or - get_secret("AI211_API_KEY") - ) - # aleph_alpha + api_key = api_key or litellm.ai21_key or get_secret("AI211_API_KEY") + # aleph_alpha elif llm_provider == "aleph_alpha": api_key = ( - api_key or - litellm.aleph_alpha_key or - get_secret("ALEPH_ALPHA_API_KEY") + api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") ) - # baseten + # baseten elif llm_provider == "baseten": - api_key = ( - api_key or - litellm.baseten_key or - get_secret("BASETEN_API_KEY") - ) - # cohere + api_key = api_key or litellm.baseten_key or get_secret("BASETEN_API_KEY") + # cohere elif llm_provider == "cohere": - api_key = ( - api_key or - litellm.cohere_key or - get_secret("COHERE_API_KEY") - ) - # huggingface + api_key = api_key or litellm.cohere_key or get_secret("COHERE_API_KEY") + # huggingface elif llm_provider == "huggingface": api_key = ( - api_key or - litellm.huggingface_key or - get_secret("HUGGINGFACE_API_KEY") + api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") ) - # nlp_cloud + # nlp_cloud elif llm_provider == "nlp_cloud": - api_key = ( - api_key or - litellm.nlp_cloud_key or - get_secret("NLP_CLOUD_API_KEY") - ) - # replicate + api_key = api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") + # replicate elif llm_provider == "replicate": - api_key = ( - api_key or - litellm.replicate_key or - get_secret("REPLICATE_API_KEY") - ) - # together_ai + api_key = api_key or litellm.replicate_key or get_secret("REPLICATE_API_KEY") + # together_ai elif llm_provider == "together_ai": api_key = ( - api_key or - litellm.togetherai_api_key or - get_secret("TOGETHERAI_API_KEY") or - get_secret("TOGETHER_AI_TOKEN") + api_key + or litellm.togetherai_api_key + or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") ) return api_key + def get_max_tokens(model: str): """ Get the maximum number of tokens allowed for a given model. 
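An illustrative sketch of how the two helpers reformatted above behave, based only on the signatures and branches shown in this hunk; the model string is an assumption and the provider-prefix behaviour is described rather than executed.

from litellm.utils import get_api_key, get_llm_provider

# Plain model names resolve against litellm's built-in model lists.
model, provider, dynamic_key, api_base = get_llm_provider(model="gpt-3.5-turbo")
# -> ("gpt-3.5-turbo", "openai", None, None)

# A provider prefix such as "perplexity/<model>" would instead return
# provider == "perplexity", api_base == "https://api.perplexity.ai", and a
# dynamic key read from the PERPLEXITYAI_API_KEY secret, per the branch above.

# get_api_key falls back from an explicitly passed key, to litellm.api_key /
# litellm.<provider>_key, to the provider's environment variable.
key = get_api_key(llm_provider=provider, dynamic_api_key=dynamic_key)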
@@ -3164,6 +3784,7 @@ def get_max_tokens(model: str): >>> get_max_tokens("gpt-4") 8192 """ + def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3189,19 +3810,21 @@ def get_max_tokens(model: str): try: if model in litellm.model_cost: return litellm.model_cost[model]["max_tokens"] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return max_tokens - else: + else: raise Exception() except: - raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") + raise Exception( + "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" + ) def get_model_info(model: str): """ - Get a dict for the maximum tokens (context window), + Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. Parameters: @@ -3228,6 +3851,7 @@ def get_model_info(model: str): "mode": "chat" } """ + def _get_max_position_embeddings(model_name): # Construct the URL for the config.json file config_url = f"https://huggingface.co/{model_name}/raw/main/config.json" @@ -3249,30 +3873,34 @@ def get_model_info(model: str): return None except requests.exceptions.RequestException as e: return None + try: azure_llms = { "gpt-35-turbo": "azure/gpt-3.5-turbo", "gpt-35-turbo-16k": "azure/gpt-3.5-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct" + "gpt-35-turbo-instruct": "azure/gpt-3.5-turbo-instruct", } - if model in azure_llms: + if model in azure_llms: model = azure_llms[model] if model in litellm.model_cost: return litellm.model_cost[model] - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - if custom_llm_provider == "huggingface": + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + if custom_llm_provider == "huggingface": max_tokens = _get_max_position_embeddings(model_name=model) return { "max_tokens": max_tokens, "input_cost_per_token": 0, "output_cost_per_token": 0, "litellm_provider": "huggingface", - "mode": "chat" + "mode": "chat", } - else: + else: raise Exception() except: - raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json") + raise Exception( + "This model isn't mapped yet. 
Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" + ) + def json_schema_type(python_type_name: str): """Converts standard python types to json schema types @@ -3299,6 +3927,7 @@ def json_schema_type(python_type_name: str): return python_to_json_schema_types.get(python_type_name, "string") + def function_to_dict(input_function): # noqa: C901 """Using type hints and numpy-styled docstring, produce a dictionnary usable for OpenAI function calling @@ -3388,6 +4017,7 @@ def function_to_dict(input_function): # noqa: C901 return result + def load_test_model( model: str, custom_llm_provider: str = "", @@ -3430,13 +4060,14 @@ def load_test_model( "exception": e, } -def validate_environment(model: Optional[str]=None) -> dict: + +def validate_environment(model: Optional[str] = None) -> dict: """ Checks if the environment variables are valid for the given model. - + Args: model (Optional[str]): The name of the model. Defaults to None. - + Returns: dict: A dictionary containing the following keys: - keys_in_environment (bool): True if all the required keys are present in the environment, False otherwise. @@ -3446,7 +4077,10 @@ def validate_environment(model: Optional[str]=None) -> dict: missing_keys: List[str] = [] if model is None: - return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + return { + "keys_in_environment": keys_in_environment, + "missing_keys": missing_keys, + } ## EXTRACT LLM PROVIDER - if model name provided try: custom_llm_provider = get_llm_provider(model=model) @@ -3457,7 +4091,7 @@ def validate_environment(model: Optional[str]=None) -> dict: # custom_llm_provider = model.split("/", 1)[0] # model = model.split("/", 1)[1] # custom_llm_provider_passed_in = True - + if custom_llm_provider: if custom_llm_provider == "openai": if "OPENAI_API_KEY" in os.environ: @@ -3465,12 +4099,16 @@ def validate_environment(model: Optional[str]=None) -> dict: else: missing_keys.append("OPENAI_API_KEY") elif custom_llm_provider == "azure": - if ("AZURE_API_BASE" in os.environ + if ( + "AZURE_API_BASE" in os.environ and "AZURE_API_VERSION" in os.environ - and "AZURE_API_KEY" in os.environ): + and "AZURE_API_KEY" in os.environ + ): keys_in_environment = True else: - missing_keys.extend(["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"]) + missing_keys.extend( + ["AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_API_KEY"] + ) elif custom_llm_provider == "anthropic": if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -3492,8 +4130,7 @@ def validate_environment(model: Optional[str]=None) -> dict: else: missing_keys.append("OPENROUTER_API_KEY") elif custom_llm_provider == "vertex_ai": - if ("VERTEXAI_PROJECT" in os.environ - and "VERTEXAI_LOCATION" in os.environ): + if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) @@ -3527,20 +4164,26 @@ def validate_environment(model: Optional[str]=None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - elif custom_llm_provider == "bedrock": - if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: + elif custom_llm_provider == "bedrock": + if ( + "AWS_ACCESS_KEY_ID" in os.environ + and "AWS_SECRET_ACCESS_KEY" in os.environ + ): keys_in_environment = True else: missing_keys.append("AWS_ACCESS_KEY_ID") missing_keys.append("AWS_SECRET_ACCESS_KEY") else: ## openai - chatcompletion + text completion - if 
model in litellm.open_ai_chat_completion_models or litellm.open_ai_text_completion_models: + if ( + model in litellm.open_ai_chat_completion_models + or litellm.open_ai_text_completion_models + ): if "OPENAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("OPENAI_API_KEY") - ## anthropic + ## anthropic elif model in litellm.anthropic_models: if "ANTHROPIC_API_KEY" in os.environ: keys_in_environment = True @@ -3566,36 +4209,35 @@ def validate_environment(model: Optional[str]=None) -> dict: missing_keys.append("OPENROUTER_API_KEY") ## vertex - text + chat models elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models: - if ("VERTEXAI_PROJECT" in os.environ - and "VERTEXAI_LOCATION" in os.environ): + if "VERTEXAI_PROJECT" in os.environ and "VERTEXAI_LOCATION" in os.environ: keys_in_environment = True else: missing_keys.extend(["VERTEXAI_PROJECT", "VERTEXAI_PROJECT"]) - ## huggingface + ## huggingface elif model in litellm.huggingface_models: if "HUGGINGFACE_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("HUGGINGFACE_API_KEY") - ## ai21 + ## ai21 elif model in litellm.ai21_models: if "AI21_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("AI21_API_KEY") - ## together_ai + ## together_ai elif model in litellm.together_ai_models: if "TOGETHERAI_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("TOGETHERAI_API_KEY") - ## aleph_alpha + ## aleph_alpha elif model in litellm.aleph_alpha_models: if "ALEPH_ALPHA_API_KEY" in os.environ: keys_in_environment = True else: missing_keys.append("ALEPH_ALPHA_API_KEY") - ## baseten + ## baseten elif model in litellm.baseten_models: if "BASETEN_API_KEY" in os.environ: keys_in_environment = True @@ -3607,7 +4249,8 @@ def validate_environment(model: Optional[str]=None) -> dict: keys_in_environment = True else: missing_keys.append("NLP_CLOUD_API_KEY") - return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} + def set_callbacks(callback_list, function_id=None): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger @@ -3701,6 +4344,7 @@ def set_callbacks(callback_list, function_id=None): except Exception as e: raise e + # NOTE: DEPRECATING this in favor of using failure_handler() in Logging: def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger @@ -3844,7 +4488,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k exception_logging(logger_fn=user_logger_fn, exception=e) pass -async def convert_to_streaming_response_async(response_object: Optional[dict]=None): + +async def convert_to_streaming_response_async(response_object: Optional[dict] = None): """ Asynchronously converts a response object to a streaming response. 
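A minimal sketch of how the validate_environment() checks above surface to callers, following the return shape documented in its docstring; the model name below is an assumption.

from litellm.utils import validate_environment

env_check = validate_environment(model="claude-2")  # anthropic model -> ANTHROPIC_API_KEY check
if not env_check["keys_in_environment"]:
    print("missing keys:", env_check["missing_keys"])  # e.g. ["ANTHROPIC_API_KEY"]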
@@ -3875,7 +4520,7 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No content=choice["message"].get("content", None), role=choice["message"]["role"], function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) + tool_calls=choice["message"].get("tool_calls", None), ) finish_reason = choice.get("finish_reason", None) @@ -3891,10 +4536,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No model_response_object.usage = Usage( completion_tokens=response_object["usage"].get("completion_tokens", 0), prompt_tokens=response_object["usage"].get("prompt_tokens", 0), - total_tokens=response_object["usage"].get("total_tokens", 0) + total_tokens=response_object["usage"].get("total_tokens", 0), ) - if "id" in response_object: model_response_object.id = response_object["id"] @@ -3907,19 +4551,20 @@ async def convert_to_streaming_response_async(response_object: Optional[dict]=No yield model_response_object await asyncio.sleep(0) -def convert_to_streaming_response(response_object: Optional[dict]=None): + +def convert_to_streaming_response(response_object: Optional[dict] = None): # used for yielding Cache hits when stream == True if response_object is None: raise Exception("Error in response object format") model_response_object = ModelResponse(stream=True) - choice_list=[] - for idx, choice in enumerate(response_object["choices"]): + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): delta = Delta( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None), ) finish_reason = choice.get("finish_reason", None) if finish_reason == None: @@ -3930,100 +4575,118 @@ def convert_to_streaming_response(response_object: Optional[dict]=None): model_response_object.choices = choice_list if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: + if "id" in response_object: model_response_object.id = response_object["id"] if "system_fingerprint" in response_object: model_response_object.system_fingerprint = response_object["system_fingerprint"] - if "model" in response_object: + if "model" in response_object: model_response_object.model = response_object["model"] yield model_response_object -def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = 
"completion", stream = False): - try: - if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)): - if response_object is None or model_response_object is None: - raise Exception("Error in response object format") - if stream == True: - # for returning cached responses, we need to yield a generator - return convert_to_streaming_response( - response_object=response_object - ) - choice_list=[] - for idx, choice in enumerate(response_object["choices"]): - message = Message( - content=choice["message"].get("content", None), - role=choice["message"]["role"], - function_call=choice["message"].get("function_call", None), - tool_calls=choice["message"].get("tool_calls", None) - ) - finish_reason = choice.get("finish_reason", None) - if finish_reason == None: - # gpt-4 vision can return 'finish_reason' or 'finish_details' - finish_reason = choice.get("finish_details") - choice = Choices(finish_reason=finish_reason, index=idx, message=message) - choice_list.append(choice) - model_response_object.choices = choice_list +def convert_to_model_response_object( + response_object: Optional[dict] = None, + model_response_object: Optional[ + Union[ModelResponse, EmbeddingResponse, ImageResponse] + ] = None, + response_type: Literal[ + "completion", "embedding", "image_generation" + ] = "completion", + stream=False, +): + try: + if response_type == "completion" and ( + model_response_object is None + or isinstance(model_response_object, ModelResponse) + ): + if response_object is None or model_response_object is None: + raise Exception("Error in response object format") + if stream == True: + # for returning cached responses, we need to yield a generator + return convert_to_streaming_response(response_object=response_object) + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): + message = Message( + content=choice["message"].get("content", None), + role=choice["message"]["role"], + function_call=choice["message"].get("function_call", None), + tool_calls=choice["message"].get("tool_calls", None), + ) + finish_reason = choice.get("finish_reason", None) + if finish_reason == None: + # gpt-4 vision can return 'finish_reason' or 'finish_details' + finish_reason = choice.get("finish_details") + choice = Choices( + finish_reason=finish_reason, index=idx, message=message + ) + choice_list.append(choice) + model_response_object.choices = choice_list - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - if "id" in response_object: - model_response_object.id = response_object["id"] - - if "system_fingerprint" in response_object: - model_response_object.system_fingerprint = response_object["system_fingerprint"] + if "id" in response_object: + model_response_object.id = 
response_object["id"] - if "model" in response_object: - model_response_object.model = response_object["model"] - return model_response_object - elif response_type == "embedding" and (model_response_object is None or isinstance(model_response_object, EmbeddingResponse)): - if response_object is None: - raise Exception("Error in response object format") - - if model_response_object is None: - model_response_object = EmbeddingResponse() + if "system_fingerprint" in response_object: + model_response_object.system_fingerprint = response_object[ + "system_fingerprint" + ] - if "model" in response_object: - model_response_object.model = response_object["model"] - - if "object" in response_object: - model_response_object.object = response_object["object"] + if "model" in response_object: + model_response_object.model = response_object["model"] + return model_response_object + elif response_type == "embedding" and ( + model_response_object is None + or isinstance(model_response_object, EmbeddingResponse) + ): + if response_object is None: + raise Exception("Error in response object format") - + if model_response_object is None: + model_response_object = EmbeddingResponse() + + if "model" in response_object: + model_response_object.model = response_object["model"] + + if "object" in response_object: + model_response_object.object = response_object["object"] + + model_response_object.data = response_object["data"] + + if "usage" in response_object and response_object["usage"] is not None: + model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore + model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore + model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + + return model_response_object + elif response_type == "image_generation" and ( + model_response_object is None + or isinstance(model_response_object, ImageResponse) + ): + if response_object is None: + raise Exception("Error in response object format") + + if model_response_object is None: + model_response_object = ImageResponse() + + if "created" in response_object: + model_response_object.created = response_object["created"] + + if "data" in response_object: model_response_object.data = response_object["data"] - if "usage" in response_object and response_object["usage"] is not None: - model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore - model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore - model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore - - - return model_response_object - elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)): - if response_object is None: - raise Exception("Error in response object format") - - if model_response_object is None: - model_response_object = ImageResponse() - - if "created" in response_object: - model_response_object.created = response_object["created"] - - if "data" in response_object: - model_response_object.data = response_object["data"] - - return model_response_object - except Exception as e: - raise Exception(f"Invalid response object {e}") + return model_response_object + except Exception as e: + raise Exception(f"Invalid response object {e}") # NOTE: DEPRECATING this in favor of using 
success_handler() in Logging: @@ -4129,6 +4792,7 @@ def valid_model(model): except: raise BadRequestError(message="", model=model, llm_provider="") + def check_valid_key(model: str, api_key: str): """ Checks if a given API key is valid for a specific model by making a litellm.completion call with max_tokens=10 @@ -4142,16 +4806,19 @@ def check_valid_key(model: str, api_key: str): """ messages = [{"role": "user", "content": "Hey, how's it going?"}] try: - litellm.completion(model=model, messages=messages, api_key=api_key, max_tokens=10) + litellm.completion( + model=model, messages=messages, api_key=api_key, max_tokens=10 + ) return True except AuthenticationError as e: return False except Exception as e: return False -def _should_retry(status_code: int): + +def _should_retry(status_code: int): """ - Reimplementation of openai's should retry logic, since that one can't be imported. + Reimplementation of openai's should retry logic, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L639 """ # If the server explicitly says whether or not to retry, obey. @@ -4173,13 +4840,20 @@ def _should_retry(status_code: int): return False -def _calculate_retry_after(remaining_retries: int, max_retries: int, response_headers: Optional[httpx.Headers]=None, min_timeout: int = 0): + +def _calculate_retry_after( + remaining_retries: int, + max_retries: int, + response_headers: Optional[httpx.Headers] = None, + min_timeout: int = 0, +): """ Reimplementation of openai's calculate retry after, since that one can't be imported. https://github.com/openai/openai-python/blob/af67cfab4210d8e497c05390ce14f39105c77519/src/openai/_base_client.py#L631 """ try: - import email # openai import + import email # openai import + # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After # # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for @@ -4200,7 +4874,7 @@ def _calculate_retry_after(remaining_retries: int, max_retries: int, response_he except Exception: retry_after = -1 - + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. if 0 < retry_after <= 60: return retry_after @@ -4217,6 +4891,7 @@ def _calculate_retry_after(remaining_retries: int, max_retries: int, response_he timeout = sleep_seconds * jitter return timeout if timeout >= min_timeout else min_timeout + # integration helper function def modify_integration(integration_name, integration_params): global supabaseClient @@ -4226,7 +4901,12 @@ def modify_integration(integration_name, integration_params): # custom prompt helper function -def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""): +def register_prompt_template( + model: str, + roles: dict, + initial_prompt_value: str = "", + final_prompt_value: str = "", +): """ Register a prompt template to follow your custom format for a given model @@ -4240,19 +4920,19 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str dict: The updated custom prompt dictionary. 
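A sketch of how the two private retry helpers reformatted above fit together; `resp` and `retry_with_backoff` are assumed names for illustration, with `resp` standing in for an httpx.Response from a failed call.

import time

import httpx

from litellm.utils import _calculate_retry_after, _should_retry


def retry_with_backoff(resp: httpx.Response, remaining_retries: int, max_retries: int) -> None:
    # Mirrors the OpenAI-style retry logic reimplemented above: obey the server's
    # retry hint when the status code is retryable, otherwise do nothing.
    if _should_retry(resp.status_code):
        time.sleep(
            _calculate_retry_after(
                remaining_retries=remaining_retries,
                max_retries=max_retries,
                response_headers=resp.headers,
            )
        )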
Example usage: ``` - import litellm + import litellm litellm.register_prompt_template( - model="llama-2", + model="llama-2", initial_prompt_value="You are a good assistant" # [OPTIONAL] - roles={ + roles={ "system": { "pre_message": "[INST] <>\n", # [OPTIONAL] "post_message": "\n<>\n [/INST]\n" # [OPTIONAL] }, - "user": { + "user": { "pre_message": "[INST] ", # [OPTIONAL] "post_message": " [/INST]" # [OPTIONAL] - }, + }, "assistant": { "pre_message": "\n" # [OPTIONAL] "post_message": "\n" # [OPTIONAL] @@ -4266,11 +4946,12 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str litellm.custom_prompt_dict[model] = { "roles": roles, "initial_prompt_value": initial_prompt_value, - "final_prompt_value": final_prompt_value + "final_prompt_value": final_prompt_value, } return litellm.custom_prompt_dict -####### DEPRECATED ################ + +####### DEPRECATED ################ def get_all_keys(llm_provider=None): @@ -4359,20 +5040,25 @@ def get_model_list(): f"[Non-Blocking Error] get_model_list error - {traceback.format_exc()}" ) + ####### EXCEPTION MAPPING ################ def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - ): + model, + original_exception, + custom_llm_provider, + completion_kwargs={}, +): global user_logger_fn, liteDebuggerClient exception_mapping_worked = False if litellm.suppress_debug_info is False: - print() # noqa - print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m") # noqa - print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.") # noqa - print() # noqa + print() # noqa + print( + "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" + ) # noqa + print( + "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." 
+ ) # noqa + print() # noqa try: if model: error_str = str(original_exception) @@ -4380,39 +5066,53 @@ def exception_type( exception_type = type(original_exception).__name__ else: exception_type = "" - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: + + if "Request Timeout Error" in error_str or "Request timed out" in error_str: exception_mapping_worked = True raise Timeout( message=f"APITimeoutError - Request timed out", model=model, - llm_provider=custom_llm_provider + llm_provider=custom_llm_provider, ) - if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai" or custom_llm_provider in litellm.openai_compatible_providers: - if "This model's maximum context length is" in error_str or "Request too large" in error_str: + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "custom_openai" + or custom_llm_provider in litellm.openai_compatible_providers + ): + if ( + "This model's maximum context length is" in error_str + or "Request too large" in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "invalid_request_error" in error_str and "model_not_found" in error_str: + elif ( + "invalid_request_error" in error_str + and "model_not_found" in error_str + ): exception_mapping_worked = True raise NotFoundError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "invalid_request_error" in error_str and "Incorrect API key provided" not in error_str: + elif ( + "invalid_request_error" in error_str + and "Incorrect API key provided" not in error_str + ): exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -4422,7 +5122,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 404: exception_mapping_worked = True @@ -4430,7 +5130,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4445,7 +5145,7 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4453,17 +5153,17 @@ def exception_type( message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 503: + elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( 
message=f"OpenAIException - {original_exception.message}", model=model, llm_provider="openai", - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 504: # gateway timeout error + elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( message=f"OpenAIException - {original_exception.message}", @@ -4473,11 +5173,11 @@ def exception_type( else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"OpenAIException - {original_exception.message}", llm_provider="openai", model=model, - request=original_exception.request + request=original_exception.request, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -4485,25 +5185,28 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider=custom_llm_provider, model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "anthropic": # one of the anthropics if hasattr(original_exception, "message"): - if "prompt is too long" in original_exception.message or "prompt: length" in original_exception.message: + if ( + "prompt is too long" in original_exception.message + or "prompt: length" in original_exception.message + ): exception_mapping_worked = True raise ContextWindowExceededError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) if "Invalid API Key" in original_exception.message: exception_mapping_worked = True raise AuthenticationError( - message=original_exception.message, + message=original_exception.message, model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): print_verbose(f"status_code: {original_exception.status_code}") @@ -4513,15 +5216,18 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 400 or original_exception.status_code == 413: + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 413 + ): exception_mapping_worked = True raise BadRequestError( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4529,7 +5235,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4537,7 +5243,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4545,7 +5251,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", 
llm_provider="anthropic", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True @@ -4554,7 +5260,7 @@ def exception_type( message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "replicate": if "Incorrect authentication token" in error_str: @@ -4563,7 +5269,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif "input is too long" in error_str: exception_mapping_worked = True @@ -4571,7 +5277,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif exception_type == "ModelError": exception_mapping_worked = True @@ -4579,7 +5285,7 @@ def exception_type( message=f"ReplicateException - {error_str}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif "Request was throttled" in error_str: exception_mapping_worked = True @@ -4587,7 +5293,7 @@ def exception_type( message=f"ReplicateException - {error_str}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4596,15 +5302,19 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 400 or original_exception.status_code == 422 or original_exception.status_code == 413: + elif ( + original_exception.status_code == 400 + or original_exception.status_code == 422 + or original_exception.status_code == 413 + ): exception_mapping_worked = True raise BadRequestError( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4612,7 +5322,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", model=model, llm_provider="replicate", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4620,7 +5330,7 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4628,48 +5338,60 @@ def exception_type( message=f"ReplicateException - {original_exception.message}", llm_provider="replicate", model=model, - response=original_exception.response + response=original_exception.response, ) exception_mapping_worked = True raise APIError( - status_code=500, + status_code=500, message=f"ReplicateException - {str(original_exception)}", llm_provider="replicate", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "bedrock": - if "too many tokens" 
in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str: + if ( + "too many tokens" in error_str + or "expected maxLength:" in error_str + or "Input is too long" in error_str + or "Too many input tokens" in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"BedrockException: Context Window Error - {error_str}", - model=model, + model=model, llm_provider="bedrock", - response=original_exception.response + response=original_exception.response, ) if "Malformed input request" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, + message=f"BedrockException - {error_str}", + model=model, llm_provider="bedrock", - response=original_exception.response + response=original_exception.response, ) - if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str: + if ( + "Unable to locate credentials" in error_str + or "The security token included in the request is invalid" + in error_str + ): exception_mapping_worked = True raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response + message=f"BedrockException Invalid Authentication - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, ) - if "throttlingException" in error_str or "ThrottlingException" in error_str: + if ( + "throttlingException" in error_str + or "ThrottlingException" in error_str + ): exception_mapping_worked = True raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response + message=f"BedrockException: Rate Limit Error - {error_str}", + model=model, + llm_provider="bedrock", + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 500: @@ -4678,7 +5400,7 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 401: exception_mapping_worked = True @@ -4686,49 +5408,55 @@ def exception_type( message=f"BedrockException - {original_exception.message}", llm_provider="bedrock", model=model, - response=original_exception.response + response=original_exception.response, ) - elif custom_llm_provider == "sagemaker": + elif custom_llm_provider == "sagemaker": if "Unable to locate credentials" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, + message=f"SagemakerException - {error_str}", + model=model, llm_provider="sagemaker", - response=original_exception.response + response=original_exception.response, ) - elif "Input validation error: `best_of` must be > 0 and <= 2" in error_str: + elif ( + "Input validation error: `best_of` must be > 0 and <= 2" + in error_str + ): exception_mapping_worked = True raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, + message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", + model=model, llm_provider="sagemaker", - response=original_exception.response + 
response=original_exception.response, ) elif custom_llm_provider == "vertex_ai": - if "Vertex AI API has not been used in project" in error_str or "Unable to find your project" in error_str: + if ( + "Vertex AI API has not been used in project" in error_str + or "Unable to find your project" in error_str + ): exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) - elif "403" in error_str: + elif "403" in error_str: exception_mapping_worked = True raise U( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) elif "The response was blocked." in error_str: exception_mapping_worked = True raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, + message=f"VertexAIException - {error_str}", + model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -4737,16 +5465,16 @@ def exception_type( message=f"VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response + response=original_exception.response, ) - if original_exception.status_code == 500: + if original_exception.status_code == 500: exception_mapping_worked = True raise APIError( message=f"VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "palm": if "503 Getting metadata" in error_str: @@ -4754,10 +5482,10 @@ def exception_type( # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. 
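Aside (not part of the diff): the vertex_ai hunks above and the palm branch that continues below all follow the same mapping pattern that runs through `exception_type` — inspect the provider error's message and status code, then re-raise one of litellm's typed exceptions with `message`, `model`, `llm_provider`, and either `response` or `request` attached. The sketch below is a minimal, self-contained illustration of that pattern under assumed names; the exception classes are local stand-ins defined only for the example, not litellm's real imports.

```python
# Illustrative sketch of the provider-error mapping pattern shown in these hunks.
# The exception classes are local stand-ins, not litellm's actual classes.
from typing import Optional


class BadRequestError(Exception):
    def __init__(self, message: str, model: str, llm_provider: str, response=None):
        super().__init__(message)
        self.model, self.llm_provider, self.response = model, llm_provider, response


class RateLimitError(BadRequestError):
    pass


class APIError(BadRequestError):
    pass


def map_provider_error(original_exception: Exception, model: str, provider: str):
    """Re-raise a provider error as a typed exception, mirroring the branches above."""
    error_str = str(original_exception)
    status_code: Optional[int] = getattr(original_exception, "status_code", None)
    response = getattr(original_exception, "response", None)

    # 1) substring checks against the provider's error text come first
    if "Malformed input request" in error_str:
        raise BadRequestError(
            message=f"{provider}Exception - {error_str}",
            model=model,
            llm_provider=provider,
            response=response,
        )
    # 2) then HTTP status codes, when the original exception carries one
    if status_code == 429:
        raise RateLimitError(
            message=f"{provider}Exception - {error_str}",
            model=model,
            llm_provider=provider,
            response=response,
        )
    # 3) anything unrecognised falls through to a generic APIError
    raise APIError(
        message=f"{provider}Exception - {error_str}",
        model=model,
        llm_provider=provider,
        response=response,
    )
```

Calling `map_provider_error(err, model="claude-2", provider="Bedrock")` from inside an `except` block reproduces the shape of every branch in this section: string checks first, status-code checks second, and a catch-all `APIError` last, with `exception_mapping_worked` (omitted here) flipped to `True` before each raise so the outer handler knows the mapping succeeded.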
exception_mapping_worked = True raise BadRequestError( - message=f"PalmException - Invalid api key", - model=model, + message=f"PalmException - Invalid api key", + model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) if "400 Request payload size exceeds" in error_str: exception_mapping_worked = True @@ -4765,7 +5493,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 400: @@ -4774,7 +5502,7 @@ def exception_type( message=f"PalmException - {error_str}", model=model, llm_provider="palm", - response=original_exception.response + response=original_exception.response, ) # Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes elif custom_llm_provider == "cohere": # Cohere @@ -4787,7 +5515,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "too many tokens" in error_str: exception_mapping_worked = True @@ -4795,16 +5523,19 @@ def exception_type( message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere", - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 400 or original_exception.status_code == 498: + if ( + original_exception.status_code == 400 + or original_exception.status_code == 498 + ): exception_mapping_worked = True raise BadRequestError( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -4812,7 +5543,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif ( "CohereConnectionError" in exception_type @@ -4822,7 +5553,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "invalid type:" in error_str: exception_mapping_worked = True @@ -4830,7 +5561,7 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) elif "Unexpected server error" in error_str: exception_mapping_worked = True @@ -4838,17 +5569,17 @@ def exception_type( message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - response=original_exception.response + response=original_exception.response, ) else: if hasattr(original_exception, "status_code"): exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"CohereException - {original_exception.message}", llm_provider="cohere", model=model, - request=original_exception.request + request=original_exception.request, ) raise original_exception elif custom_llm_provider == "huggingface": @@ -4858,15 +5589,15 @@ def 
exception_type( message=error_str, model=model, llm_provider="huggingface", - response=original_exception.response + response=original_exception.response, ) elif "A valid user token is required" in error_str: exception_mapping_worked = True raise BadRequestError( - message=error_str, + message=error_str, llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4875,7 +5606,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -4883,7 +5614,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4891,7 +5622,7 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", - request=original_exception.request + request=original_exception.request, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4899,16 +5630,16 @@ def exception_type( message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "ai21": if hasattr(original_exception, "message"): @@ -4918,15 +5649,15 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) - if "Bad or missing API token." in original_exception.message: + if "Bad or missing API token." 
in original_exception.message: exception_mapping_worked = True raise BadRequestError( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 401: @@ -4935,7 +5666,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -4943,7 +5674,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - request=original_exception.request + request=original_exception.request, ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -4951,7 +5682,7 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", model=model, llm_provider="ai21", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -4959,16 +5690,16 @@ def exception_type( message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AI21Exception - {original_exception.message}", llm_provider="ai21", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "nlp_cloud": if "detail" in error_str: @@ -4978,7 +5709,7 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) elif "value is not a valid" in error_str: exception_mapping_worked = True @@ -4986,140 +5717,180 @@ def exception_type( message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) - else: + else: exception_mapping_worked = True raise APIError( status_code=500, message=f"NLPCloudException - {error_str}", model=model, llm_provider="nlp_cloud", - request=original_exception.request + request=original_exception.request, ) - if hasattr(original_exception, "status_code"): # https://docs.nlpcloud.com/?shell#errors - if original_exception.status_code == 400 or original_exception.status_code == 406 or original_exception.status_code == 413 or original_exception.status_code == 422: + if hasattr( + original_exception, "status_code" + ): # https://docs.nlpcloud.com/?shell#errors + if ( + original_exception.status_code == 400 + or original_exception.status_code == 406 + or original_exception.status_code == 413 + or original_exception.status_code == 422 + ): exception_mapping_worked = True raise BadRequestError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 401 or original_exception.status_code == 403: + elif ( + original_exception.status_code == 401 + or original_exception.status_code == 403 + ): exception_mapping_worked = True raise AuthenticationError( 
message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 522 or original_exception.status_code == 524: + elif ( + original_exception.status_code == 522 + or original_exception.status_code == 524 + ): exception_mapping_worked = True raise Timeout( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - request=original_exception.request + request=original_exception.request, ) - elif original_exception.status_code == 429 or original_exception.status_code == 402: + elif ( + original_exception.status_code == 429 + or original_exception.status_code == 402 + ): exception_mapping_worked = True raise RateLimitError( message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - response=original_exception.response + response=original_exception.response, ) - elif original_exception.status_code == 500 or original_exception.status_code == 503: + elif ( + original_exception.status_code == 500 + or original_exception.status_code == 503 + ): exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request + request=original_exception.request, ) - elif original_exception.status_code == 504 or original_exception.status_code == 520: + elif ( + original_exception.status_code == 504 + or original_exception.status_code == 520 + ): exception_mapping_worked = True raise ServiceUnavailableError( message=f"NLPCloudException - {original_exception.message}", model=model, llm_provider="nlp_cloud", - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"NLPCloudException - {original_exception.message}", llm_provider="nlp_cloud", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "together_ai": import json + try: error_response = json.loads(error_str) except: error_response = {"error": error_str} - if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]: + if ( + "error" in error_response + and "`inputs` tokens + `max_new_tokens` must be <=" + in error_response["error"] + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - elif "error" in error_response and "invalid private key" in error_response["error"]: + elif ( + "error" in error_response + and "invalid private key" in error_response["error"] + ): exception_mapping_worked = True raise AuthenticationError( message=f"TogetherAIException - {error_response['error']}", llm_provider="together_ai", model=model, - response=original_exception.response + response=original_exception.response, ) - elif "error" in error_response and "INVALID_ARGUMENT" in error_response["error"]: + elif ( + "error" in error_response + and "INVALID_ARGUMENT" in error_response["error"] + ): exception_mapping_worked = True raise BadRequestError( 
message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - - elif "error" in error_response and "API key doesn't match expected format." in error_response["error"]: + + elif ( + "error" in error_response + and "API key doesn't match expected format." + in error_response["error"] + ): exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) - elif "error_type" in error_response and error_response["error_type"] == "validation": + elif ( + "error_type" in error_response + and error_response["error_type"] == "validation" + ): exception_mapping_worked = True raise BadRequestError( message=f"TogetherAIException - {error_response['error']}", model=model, llm_provider="together_ai", - response=original_exception.response + response=original_exception.response, ) if hasattr(original_exception, "status_code"): if original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"TogetherAIException - {original_exception.message}", - model=model, - llm_provider="together_ai", - request=original_exception.request - ) + exception_mapping_worked = True + raise Timeout( + message=f"TogetherAIException - {original_exception.message}", + model=model, + llm_provider="together_ai", + request=original_exception.request, + ) elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"TogetherAIException - {original_exception.message}", - llm_provider="together_ai", - model=model, - response=original_exception.response - ) + exception_mapping_worked = True + raise RateLimitError( + message=f"TogetherAIException - {original_exception.message}", + llm_provider="together_ai", + model=model, + response=original_exception.response, + ) elif original_exception.status_code == 524: exception_mapping_worked = True raise Timeout( @@ -5127,31 +5898,34 @@ def exception_type( llm_provider="together_ai", model=model, ) - else: + else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"TogetherAIException - {original_exception.message}", llm_provider="together_ai", model=model, - request=original_exception.request + request=original_exception.request, ) elif custom_llm_provider == "aleph_alpha": - if "This is longer than the model's maximum context length" in error_str: + if ( + "This is longer than the model's maximum context length" + in error_str + ): exception_mapping_worked = True raise ContextWindowExceededError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif "InvalidToken" in error_str or "No token provided" in error_str: exception_mapping_worked = True raise BadRequestError( message=f"AlephAlphaException - {original_exception.message}", - llm_provider="aleph_alpha", + llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): print_verbose(f"status code: {original_exception.status_code}") @@ -5160,7 +5934,7 @@ def exception_type( raise 
AuthenticationError( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", - model=model + model=model, ) elif original_exception.status_code == 400: exception_mapping_worked = True @@ -5168,7 +5942,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5176,7 +5950,7 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 500: exception_mapping_worked = True @@ -5184,33 +5958,35 @@ def exception_type( message=f"AlephAlphaException - {original_exception.message}", llm_provider="aleph_alpha", model=model, - response=original_exception.response + response=original_exception.response, ) raise original_exception raise original_exception elif custom_llm_provider == "ollama": if "no attribute 'async_get_ollama_response_stream" in error_str: exception_mapping_worked = True - raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'") + raise ImportError( + "Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'" + ) if isinstance(original_exception, dict): error_str = original_exception.get("error", "") - else: + else: error_str = str(original_exception) if "no such file or directory" in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", - model=model, - llm_provider="ollama", - response=original_exception.response - ) - elif "Failed to establish a new connection" in error_str: + message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", + model=model, + llm_provider="ollama", + response=original_exception.response, + ) + elif "Failed to establish a new connection" in error_str: exception_mapping_worked = True raise ServiceUnavailableError( message=f"OllamaException: {original_exception}", - llm_provider="ollama", + llm_provider="ollama", model=model, - response=original_exception.response + response=original_exception.response, ) elif "Invalid response object from API" in error_str: exception_mapping_worked = True @@ -5218,7 +5994,7 @@ def exception_type( message=f"OllamaException: {original_exception}", llm_provider="ollama", model=model, - response=original_exception.response + response=original_exception.response, ) elif custom_llm_provider == "vllm": if hasattr(original_exception, "status_code"): @@ -5228,16 +6004,16 @@ def exception_type( message=f"VLLMException - {original_exception.message}", llm_provider="vllm", model=model, - request=original_exception.request + request=original_exception.request, ) - elif custom_llm_provider == "azure": + elif custom_llm_provider == "azure": if "This model's maximum context length is" in error_str: exception_mapping_worked = True raise ContextWindowExceededError( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif "DeploymentNotFound" in error_str: exception_mapping_worked = True @@ -5245,7 +6021,7 @@ def exception_type( 
message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif "invalid_request_error" in error_str: exception_mapping_worked = True @@ -5253,7 +6029,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif hasattr(original_exception, "status_code"): exception_mapping_worked = True @@ -5263,7 +6039,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 408: exception_mapping_worked = True @@ -5271,7 +6047,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - request=original_exception.request + request=original_exception.request, ) if original_exception.status_code == 422: exception_mapping_worked = True @@ -5279,7 +6055,7 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response + response=original_exception.response, ) elif original_exception.status_code == 429: exception_mapping_worked = True @@ -5287,16 +6063,16 @@ def exception_type( message=f"AzureException - {original_exception.message}", model=model, llm_provider="azure", - response=original_exception.response + response=original_exception.response, ) else: exception_mapping_worked = True raise APIError( - status_code=original_exception.status_code, + status_code=original_exception.status_code, message=f"AzureException - {original_exception.message}", llm_provider="azure", model=model, - request=original_exception.request + request=original_exception.request, ) else: # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors @@ -5304,31 +6080,36 @@ def exception_type( __cause__=original_exception.__cause__, llm_provider="azure", model=model, - request=original_exception.request + request=original_exception.request, ) - if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk + if ( + "BadRequestError.__init__() missing 1 required positional argument: 'param'" + in str(original_exception) + ): # deal with edge-case invalid request error bug in openai-python sdk exception_mapping_worked = True raise BadRequestError( message=f"OpenAIException: This can happen due to missing AZURE_API_VERSION: {str(original_exception)}", - model=model, + model=model, llm_provider=custom_llm_provider, - response=original_exception.response + response=original_exception.response, ) - else: # ensure generic errors always return APIConnectionError= + else: # ensure generic errors always return APIConnectionError= exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request=original_exception.request + request=original_exception.request, ) - else: - raise APIConnectionError( + else: + raise APIConnectionError( message=f"{str(original_exception)}", llm_provider=custom_llm_provider, model=model, - request= httpx.Request(method="POST", 
url="https://api.openai.com/v1/") # stub the request + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), # stub the request ) except Exception as e: # LOGGING @@ -5362,9 +6143,10 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None): executor.submit(litellm_telemetry, data) # threading.Thread(target=litellm_telemetry, args=(data,), daemon=True).start() + def get_or_generate_uuid(): temp_dir = os.path.join(os.path.abspath(os.sep), "tmp") - uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") + uuid_file = os.path.join(temp_dir, "litellm_uuid.txt") try: # Try to open the file and load the UUID with open(uuid_file, "r") as file: @@ -5376,19 +6158,19 @@ def get_or_generate_uuid(): except FileNotFoundError: # Generate a new UUID if the file doesn't exist or is empty - try: + try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open(uuid_file, "w") as file: file.write(uuid_value) - except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt + except: # if writing to tmp/litellm_uuid.txt then retry writing to litellm_uuid.txt try: new_uuid = uuid.uuid4() uuid_value = str(new_uuid) with open("litellm_uuid.txt", "w") as file: file.write(uuid_value) - except: # if this 3rd attempt fails just pass - # Good first issue for someone to improve this function :) + except: # if this 3rd attempt fails just pass + # Good first issue for someone to improve this function :) return except: # [Non-Blocking Error] @@ -5405,17 +6187,13 @@ def litellm_telemetry(data): uuid_value = str(uuid.uuid4()) try: # Prepare the data to send to litellm logging api - try: + try: pkg_version = importlib.metadata.version("litellm") except: pkg_version = None if "model" not in data: data["model"] = None - payload = { - "uuid": uuid_value, - "data": data, - "version:": pkg_version - } + payload = {"uuid": uuid_value, "data": data, "version:": pkg_version} # Make the POST request to litellm logging api response = requests.post( "https://litellm-logging.onrender.com/logging", @@ -5427,29 +6205,33 @@ def litellm_telemetry(data): # [Non-Blocking Error] return + ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str, default_value: Optional[str]=None): - if secret_name.startswith("os.environ/"): +def get_secret(secret_name: str, default_value: Optional[str] = None): + if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") - try: + try: if litellm.secret_manager_client is not None: try: client = litellm.secret_manager_client - if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + if ( + type(client).__module__ + "." 
+ type(client).__name__ + == "azure.keyvault.secrets._client.SecretClient" + ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient secret = retrieved_secret = client.get_secret(secret_name).value - else: # assume the default is infisicial client + else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ + except: # check if it's in os.environ secret = os.environ.get(secret_name) return secret else: return os.environ.get(secret_name) - except Exception as e: - if default_value is not None: + except Exception as e: + if default_value is not None: return default_value - else: + else: raise e @@ -5457,7 +6239,9 @@ def get_secret(secret_name: str, default_value: Optional[str]=None): # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere class CustomStreamWrapper: - def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None): + def __init__( + self, completion_stream, model, custom_llm_provider=None, logging_obj=None + ): self.model = model self.custom_llm_provider = custom_llm_provider self.logging_obj = logging_obj @@ -5465,7 +6249,7 @@ class CustomStreamWrapper: self.sent_first_chunk = False self.sent_last_chunk = False self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "", ""] - self.holding_chunk = "" + self.holding_chunk = "" self.complete_response = "" def __iter__(self): @@ -5474,94 +6258,111 @@ class CustomStreamWrapper: def __aiter__(self): return self - def process_chunk(self, chunk: str): + def process_chunk(self, chunk: str): """ NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta. """ - try: + try: chunk = chunk.strip() self.complete_response = self.complete_response.strip() - if chunk.startswith(self.complete_response): + if chunk.startswith(self.complete_response): # Remove last_sent_chunk only if it appears at the start of the new chunk - chunk = chunk[len(self.complete_response):] + chunk = chunk[len(self.complete_response) :] self.complete_response += chunk - return chunk - except Exception as e: + return chunk + except Exception as e: raise e - - def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): + + def check_special_tokens(self, chunk: str, finish_reason: Optional[str]): hold = False - if finish_reason: - for token in self.special_tokens: + if finish_reason: + for token in self.special_tokens: if token in chunk: - chunk = chunk.replace(token, "") + chunk = chunk.replace(token, "") return hold, chunk - + if self.sent_first_chunk is True: return hold, chunk curr_chunk = self.holding_chunk + chunk curr_chunk = curr_chunk.strip() - for token in self.special_tokens: - if len(curr_chunk) < len(token) and curr_chunk in token: + for token in self.special_tokens: + if len(curr_chunk) < len(token) and curr_chunk in token: hold = True elif len(curr_chunk) >= len(token): if token in curr_chunk: self.holding_chunk = curr_chunk.replace(token, "") hold = True - else: + else: pass - - if hold is False: # reset - self.holding_chunk = "" - return hold, curr_chunk + if hold is False: # reset + self.holding_chunk = "" + return hold, curr_chunk def handle_anthropic_chunk(self, chunk): str_line = chunk.decode("utf-8") # Convert bytes to string - text = "" + text = "" is_finished = False finish_reason = None if str_line.startswith("data:"): data_json = json.loads(str_line[5:]) - text = data_json.get("completion", "") - if 
data_json.get("stop_reason", None): + text = data_json.get("completion", "") + if data_json.get("stop_reason", None): is_finished = True finish_reason = data_json["stop_reason"] - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "error" in str_line: raise ValueError(f"Unable to parse response. Original response: {str_line}") else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_together_ai_chunk(self, chunk): chunk = chunk.decode("utf-8") - text = "" + text = "" is_finished = False finish_reason = None - if "text" in chunk: + if "text" in chunk: text_index = chunk.find('"text":"') # this checks if text: exists text_start = text_index + len('"text":"') text_end = chunk.find('"}', text_start) if text_index != -1 and text_end != -1: extracted_text = chunk[text_start:text_end] text = extracted_text - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "[DONE]" in chunk: return {"text": text, "is_finished": True, "finish_reason": "stop"} elif "error" in chunk: raise ValueError(chunk) else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_huggingface_chunk(self, chunk): try: if type(chunk) != str: - chunk = chunk.decode("utf-8") # DO NOT REMOVE this: This is required for HF inference API + Streaming - text = "" + chunk = chunk.decode( + "utf-8" + ) # DO NOT REMOVE this: This is required for HF inference API + Streaming + text = "" is_finished = False finish_reason = "" print_verbose(f"chunk: {chunk}") @@ -5570,52 +6371,72 @@ class CustomStreamWrapper: print_verbose(f"data json: {data_json}") if "token" in data_json and "text" in data_json["token"]: text = data_json["token"]["text"] - if data_json.get("details", False) and data_json["details"].get("finish_reason", False): + if data_json.get("details", False) and data_json["details"].get( + "finish_reason", False + ): is_finished = True finish_reason = data_json["details"]["finish_reason"] - elif data_json.get("generated_text", False): # if full generated text exists, then stream is complete - text = "" # don't return the final bos token + elif data_json.get( + "generated_text", False + ): # if full generated text exists, then stream is complete + text = "" # don't return the final bos token is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - elif "error" in chunk: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + elif "error" in chunk: raise ValueError(chunk) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - except Exception as e: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + except Exception as e: traceback.print_exc() # raise(e) - - def handle_ai21_chunk(self, chunk): # fake streaming + + def handle_ai21_chunk(self, chunk): # fake streaming chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["completions"][0]["data"]["text"] is_finished = True finish_reason = "stop" - 
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_maritalk_chunk(self, chunk): # fake streaming + + def handle_maritalk_chunk(self, chunk): # fake streaming chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: text = data_json["answer"] is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_nlp_cloud_chunk(self, chunk): - text = "" + text = "" is_finished = False finish_reason = "" try: if "dolphin" in self.model: chunk = self.process_chunk(chunk=chunk) - else: + else: data_json = json.loads(chunk) chunk = data_json["generated_text"] text = chunk @@ -5623,10 +6444,14 @@ class CustomStreamWrapper: text = text.replace("[DONE]", "") is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except Exception as e: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_aleph_alpha_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) @@ -5634,28 +6459,36 @@ class CustomStreamWrapper: text = data_json["completions"][0]["completion"] is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - + def handle_cohere_chunk(self, chunk): chunk = chunk.decode("utf-8") data_json = json.loads(chunk) try: - text = "" + text = "" is_finished = False finish_reason = "" - if "text" in data_json: + if "text" in data_json: text = data_json["text"] - elif "is_finished" in data_json: + elif "is_finished" in data_json: is_finished = data_json["is_finished"] finish_reason = data_json["finish_reason"] - else: + else: raise Exception(data_json) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. 
Original response: {chunk}") - + def handle_azure_chunk(self, chunk): is_finished = False finish_reason = "" @@ -5665,72 +6498,92 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif chunk.startswith("data:"): - data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): + data_json = json.loads(chunk[5:]) # chunk.startswith("data:"): try: - if len(data_json["choices"]) > 0: - text = data_json["choices"][0]["delta"].get("content", "") - if data_json["choices"][0].get("finish_reason", None): + if len(data_json["choices"]) > 0: + text = data_json["choices"][0]["delta"].get("content", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + print_verbose( + f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" + ) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: - raise ValueError(f"Unable to parse response. Original response: {chunk}") + raise ValueError( + f"Unable to parse response. Original response: {chunk}" + ) elif "error" in chunk: raise ValueError(f"Unable to parse response. Original response: {chunk}") else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } def handle_replicate_chunk(self, chunk): try: - text = "" + text = "" is_finished = False finish_reason = "" - if "output" in chunk: - text = chunk['output'] - if "status" in chunk: + if "output" in chunk: + text = chunk["output"] + if "status" in chunk: if chunk["status"] == "succeeded": is_finished = True finish_reason = "stop" - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except: raise ValueError(f"Unable to parse response. Original response: {chunk}") - - def handle_openai_chat_completion_chunk(self, chunk): - try: + + def handle_openai_chat_completion_chunk(self, chunk): + try: print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n") str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None - original_chunk = None # this is used for function/tool calling - if len(str_line.choices) > 0: + original_chunk = None # this is used for function/tool calling + if len(str_line.choices) > 0: if str_line.choices[0].delta.content is not None: text = str_line.choices[0].delta.content - else: # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai + else: # function/tool calling chunk - when content is None. 
in this case we just return the original chunk from openai original_chunk = str_line if str_line.choices[0].finish_reason: is_finished = True finish_reason = str_line.choices[0].finish_reason return { - "text": text, - "is_finished": is_finished, + "text": text, + "is_finished": is_finished, "finish_reason": finish_reason, - "original_chunk": str_line + "original_chunk": str_line, } except Exception as e: traceback.print_exc() raise e def handle_openai_text_completion_chunk(self, chunk): - try: + try: str_line = chunk - text = "" + text = "" is_finished = False finish_reason = None print_verbose(f"str_line: {str_line}") @@ -5738,20 +6591,36 @@ class CustomStreamWrapper: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif str_line.startswith("data:"): data_json = json.loads(str_line[5:]) print_verbose(f"delta content: {data_json}") - text = data_json["choices"][0].get("text", "") - if data_json["choices"][0].get("finish_reason", None): + text = data_json["choices"][0].get("text", "") + if data_json["choices"][0].get("finish_reason", None): is_finished = True finish_reason = data_json["choices"][0]["finish_reason"] - print_verbose(f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}") - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + print_verbose( + f"text: {text}; is_finished: {is_finished}; finish_reason: {finish_reason}" + ) + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif "error" in str_line: - raise ValueError(f"Unable to parse response. Original response: {str_line}") + raise ValueError( + f"Unable to parse response. Original response: {str_line}" + ) else: - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } except Exception as e: traceback.print_exc() @@ -5769,14 +6638,22 @@ class CustomStreamWrapper: return "" data_json = json.loads(chunk) if "model_output" in data_json: - if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list): + if ( + isinstance(data_json["model_output"], dict) + and "data" in data_json["model_output"] + and isinstance(data_json["model_output"]["data"], list) + ): return data_json["model_output"]["data"][0] elif isinstance(data_json["model_output"], str): return data_json["model_output"] - elif "completion" in data_json and isinstance(data_json["completion"], str): + elif "completion" in data_json and isinstance( + data_json["completion"], str + ): return data_json["completion"] else: - raise ValueError(f"Unable to parse response. Original response: {chunk}") + raise ValueError( + f"Unable to parse response. 
Original response: {chunk}" + ) else: return "" else: @@ -5785,50 +6662,57 @@ class CustomStreamWrapper: traceback.print_exc() return "" - def handle_ollama_stream(self, chunk): - try: + def handle_ollama_stream(self, chunk): + try: json_chunk = json.loads(chunk) - if "error" in json_chunk: + if "error" in json_chunk: raise Exception(f"Ollama Error - {json_chunk}") - - text = "" + + text = "" is_finished = False finish_reason = None if json_chunk["done"] == True: text = "" is_finished = True finish_reason = "stop" - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } elif json_chunk["response"]: print_verbose(f"delta content: {json_chunk}") text = json_chunk["response"] - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} - else: + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } + else: raise Exception(f"Ollama Error - {json_chunk}") - except Exception as e: + except Exception as e: raise e - def handle_bedrock_stream(self, chunk): if hasattr(chunk, "get"): - chunk = chunk.get('chunk') - chunk_data = json.loads(chunk.get('bytes').decode()) + chunk = chunk.get("chunk") + chunk_data = json.loads(chunk.get("bytes").decode()) else: chunk_data = json.loads(chunk.decode()) if chunk_data: - text = "" + text = "" is_finished = False finish_reason = "" - if "outputText" in chunk_data: - text = chunk_data['outputText'] + if "outputText" in chunk_data: + text = chunk_data["outputText"] # ai21 mapping - if "ai21" in self.model: # fake ai21 streaming - text = chunk_data.get('completions')[0].get('data').get('text') + if "ai21" in self.model: # fake ai21 streaming + text = chunk_data.get("completions")[0].get("data").get("text") is_finished = True finish_reason = "stop" # anthropic mapping - elif "completion" in chunk_data: - text = chunk_data['completion'] # bedrock.anthropic + elif "completion" in chunk_data: + text = chunk_data["completion"] # bedrock.anthropic stop_reason = chunk_data.get("stop_reason", None) if stop_reason != None: is_finished = True @@ -5836,22 +6720,26 @@ class CustomStreamWrapper: ######## bedrock.cohere mappings ############### # meta mapping elif "generation" in chunk_data: - text = chunk_data['generation'] # bedrock.meta + text = chunk_data["generation"] # bedrock.meta # cohere mapping elif "text" in chunk_data: - text = chunk_data["text"] # bedrock.cohere + text = chunk_data["text"] # bedrock.cohere # cohere mapping for finish reason elif "finish_reason" in chunk_data: finish_reason = chunk_data["finish_reason"] is_finished = True - elif chunk_data.get("completionReason", None): + elif chunk_data.get("completionReason", None): is_finished = True finish_reason = chunk_data["completionReason"] - elif chunk.get("error", None): + elif chunk.get("error", None): raise Exception(chunk["error"]) - return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason} + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + } return "" - + def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) model_response.choices = [StreamingChoices()] @@ -5863,62 +6751,83 @@ class CustomStreamWrapper: if self.custom_llm_provider and self.custom_llm_provider == "anthropic": response_obj = self.handle_anthropic_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - 
model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.model == "replicate" or self.custom_llm_provider == "replicate": response_obj = self.handle_replicate_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - elif ( - self.custom_llm_provider and self.custom_llm_provider == "together_ai"): + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": response_obj = self.handle_together_ai_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": response_obj = self.handle_huggingface_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + elif ( + self.custom_llm_provider and self.custom_llm_provider == "baseten" + ): # baseten doesn't provide streaming completion_obj["content"] = self.handle_baseten_chunk(chunk) - elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming + elif ( + self.custom_llm_provider and self.custom_llm_provider == "ai21" + ): # ai21 doesn't provide streaming response_obj = self.handle_ai21_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "maritalk": response_obj = self.handle_maritalk_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider and self.custom_llm_provider == "vllm": completion_obj["content"] = chunk[0].outputs[0].text - elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming + elif ( + self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha" + ): # aleph alpha doesn't provide streaming response_obj = self.handle_aleph_alpha_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "nlp_cloud": - try: + try: response_obj = self.handle_nlp_cloud_chunk(chunk) completion_obj["content"] = response_obj["text"] - if 
response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] except Exception as e: if self.sent_last_chunk: raise e else: - if self.sent_first_chunk is False: + if self.sent_first_chunk is False: raise Exception("An unknown error occurred with the stream") model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai": try: # print(chunk) - if hasattr(chunk, 'text'): - # vertexAI chunks return + if hasattr(chunk, "text"): + # vertexAI chunks return # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python # This Python code says "Hi" 100 times. # Create]) @@ -5926,28 +6835,32 @@ class CustomStreamWrapper: else: completion_obj["content"] = str(chunk) except StopIteration as e: - if self.sent_last_chunk: - raise e + if self.sent_last_chunk: + raise e else: model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider == "cohere": response_obj = self.handle_cohere_chunk(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "bedrock": - if self.sent_last_chunk: + if self.sent_last_chunk: raise StopIteration response_obj = self.handle_bedrock_stream(chunk) completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": print_verbose(f"ENTERS SAGEMAKER STREAMING") - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if len(self.completion_stream) == 0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -5955,10 +6868,12 @@ class CustomStreamWrapper: new_chunk = self.completion_stream print_verbose(f"sagemaker chunk: {new_chunk}") completion_obj["content"] = new_chunk - self.completion_stream = self.completion_stream[len(self.completion_stream):] + self.completion_stream = self.completion_stream[ + len(self.completion_stream) : + ] elif self.custom_llm_provider == "petals": - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if len(self.completion_stream) == 0: + if self.sent_last_chunk: raise StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -5971,8 +6886,8 @@ class CustomStreamWrapper: elif self.custom_llm_provider == "palm": # fake streaming response_obj = {} - if len(self.completion_stream)==0: - if self.sent_last_chunk: + if len(self.completion_stream) == 0: + if self.sent_last_chunk: raise 
StopIteration else: model_response.choices[0].finish_reason = "stop" @@ -5986,33 +6901,50 @@ class CustomStreamWrapper: response_obj = self.handle_ollama_stream(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] elif self.custom_llm_provider == "text-completion-openai": response_obj = self.handle_openai_text_completion_chunk(chunk) completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] - else: # openai chat model + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] + else: # openai chat model response_obj = self.handle_openai_chat_completion_chunk(chunk) if response_obj == None: return completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") - if response_obj["is_finished"]: - model_response.choices[0].finish_reason = response_obj["finish_reason"] + if response_obj["is_finished"]: + model_response.choices[0].finish_reason = response_obj[ + "finish_reason" + ] model_response.model = self.model - print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}") - print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}") - if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string - hold, model_response_str = self.check_special_tokens(chunk=completion_obj["content"], finish_reason=model_response.choices[0].finish_reason) # filter out bos/eos tokens from openai-compatible hf endpoints - print_verbose(f"hold - {hold}, model_response_str - {model_response_str}") - if hold is False: - ## check if openai/azure chunk + print_verbose( + f"model_response: {model_response}; completion_obj: {completion_obj}" + ) + print_verbose( + f"model_response finish reason 3: {model_response.choices[0].finish_reason}" + ) + if ( + len(completion_obj["content"]) > 0 + ): # cannot set content of an OpenAI Object to be an empty string + hold, model_response_str = self.check_special_tokens( + chunk=completion_obj["content"], + finish_reason=model_response.choices[0].finish_reason, + ) # filter out bos/eos tokens from openai-compatible hf endpoints + print_verbose( + f"hold - {hold}, model_response_str - {model_response_str}" + ) + if hold is False: + ## check if openai/azure chunk original_chunk = response_obj.get("original_chunk", None) - if original_chunk: + if original_chunk: model_response.id = original_chunk.id if len(original_chunk.choices) > 0: try: @@ -6020,79 +6952,99 @@ class CustomStreamWrapper: model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: - return - model_response.system_fingerprint = original_chunk.system_fingerprint + else: + return + model_response.system_fingerprint = ( + original_chunk.system_fingerprint + ) if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True - else: - ## else - completion_obj["content"] = model_response_str + else: + ## else + completion_obj["content"] = 
model_response_str if self.sent_first_chunk == False: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) print_verbose(f"model_response: {model_response}") return model_response - else: - return + else: + return elif model_response.choices[0].finish_reason: - # flush any remaining holding chunk + # flush any remaining holding chunk if len(self.holding_chunk) > 0: if model_response.choices[0].delta.content is None: model_response.choices[0].delta.content = self.holding_chunk else: - model_response.choices[0].delta.content = self.holding_chunk + model_response.choices[0].delta.content - self.holding_chunk = "" - model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai + model_response.choices[0].delta.content = ( + self.holding_chunk + model_response.choices[0].delta.content + ) + self.holding_chunk = "" + model_response.choices[0].finish_reason = map_finish_reason( + model_response.choices[0].finish_reason + ) # ensure consistent output to openai return model_response - elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints + elif ( + response_obj is not None + and response_obj.get("original_chunk", None) is not None + ): # function / tool calling branch - only set for openai/azure compatible endpoints # enter this branch when no content has been passed in response original_chunk = response_obj.get("original_chunk", None) model_response.id = original_chunk.id if len(original_chunk.choices) > 0: - if original_chunk.choices[0].delta.function_call is not None or original_chunk.choices[0].delta.tool_calls is not None: + if ( + original_chunk.choices[0].delta.function_call is not None + or original_chunk.choices[0].delta.tool_calls is not None + ): try: delta = dict(original_chunk.choices[0].delta) model_response.choices[0].delta = Delta(**delta) except Exception as e: model_response.choices[0].delta = Delta() - else: + else: return - else: + else: return model_response.system_fingerprint = original_chunk.system_fingerprint if self.sent_first_chunk == False: model_response.choices[0].delta["role"] = "assistant" self.sent_first_chunk = True return model_response - else: + else: return except StopIteration: raise StopIteration - except Exception as e: + except Exception as e: traceback_exception = traceback.format_exc() e.message = str(e) - raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e) + raise exception_type( + model=self.model, + custom_llm_provider=self.custom_llm_provider, + original_exception=e, + ) ## needs to handle the empty string case (even starting chunk can be an empty string) def __next__(self): try: while True: - if isinstance(self.completion_stream, str) or isinstance(self.completion_stream, bytes): + if isinstance(self.completion_stream, str) or isinstance( + self.completion_stream, bytes + ): chunk = self.completion_stream else: chunk = next(self.completion_stream) print_verbose(f"value of chunk: {chunk} ") - if chunk is not None and chunk != b'': + if chunk is not None and chunk != b"": print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") response = self.chunk_creator(chunk=chunk) print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}") - if response is None: + if response is None: continue ## LOGGING - 
threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response + threading.Thread( + target=self.logging_obj.success_handler, args=(response,) + ).start() # log response return response except StopIteration: raise # Re-raise StopIteration @@ -6100,43 +7052,59 @@ class CustomStreamWrapper: print_verbose(f"HITS AN ERROR: {str(e)}\n\n {traceback.format_exc()}") traceback_exception = traceback.format_exc() # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated - threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start() + threading.Thread( + target=self.logging_obj.failure_handler, args=(e, traceback_exception) + ).start() raise e - - async def __anext__(self): try: - if (self.custom_llm_provider == "openai" + if ( + self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure" or self.custom_llm_provider == "custom_openai" or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "huggingface" or self.custom_llm_provider == "ollama" - or self.custom_llm_provider == "vertex_ai"): + or self.custom_llm_provider == "vertex_ai" + ): print_verbose(f"INSIDE ASYNC STREAMING!!!") - print_verbose(f"value of async completion stream: {self.completion_stream}") + print_verbose( + f"value of async completion stream: {self.completion_stream}" + ) async for chunk in self.completion_stream: print_verbose(f"value of async chunk: {chunk}") if chunk == "None" or chunk is None: raise Exception - # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. + # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. 
# __anext__ also calls async_success_handler, which does logging print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") - processed_chunk = self.chunk_creator(chunk=chunk) - print_verbose(f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}") - if processed_chunk is None: + processed_chunk = self.chunk_creator(chunk=chunk) + print_verbose( + f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}" + ) + if processed_chunk is None: continue ## LOGGING - threading.Thread(target=self.logging_obj.success_handler, args=(processed_chunk,)).start() # log response - asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) + threading.Thread( + target=self.logging_obj.success_handler, args=(processed_chunk,) + ).start() # log response + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) return processed_chunk raise StopAsyncIteration - else: # temporary patch for non-aiohttp async calls + else: # temporary patch for non-aiohttp async calls # example - boto3 bedrock llms processed_chunk = next(self) - asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,)) + asyncio.create_task( + self.logging_obj.async_success_handler( + processed_chunk, + ) + ) return processed_chunk except StopAsyncIteration: raise @@ -6145,9 +7113,12 @@ class CustomStreamWrapper: except Exception as e: traceback_exception = traceback.format_exc() # Handle any exceptions that might occur during streaming - asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception)) + asyncio.create_task( + self.logging_obj.async_failure_handler(e, traceback_exception) + ) raise StopAsyncIteration + class TextCompletionStreamWrapper: def __init__(self, completion_stream, model): self.completion_stream = completion_stream @@ -6158,16 +7129,18 @@ class TextCompletionStreamWrapper: def __aiter__(self): return self - + def convert_to_text_completion_object(self, chunk: ModelResponse): - try: + try: response = TextCompletionResponse() response["id"] = chunk.get("id", None) response["object"] = "text_completion" response["created"] = response.get("created", None) response["model"] = response.get("model", None) text_choices = TextChoices() - if isinstance(chunk, Choices): # chunk should always be of type StreamingChoices + if isinstance( + chunk, Choices + ): # chunk should always be of type StreamingChoices raise Exception text_choices["text"] = chunk["choices"][0]["delta"]["content"] text_choices["index"] = response["choices"][0]["index"] @@ -6175,7 +7148,9 @@ class TextCompletionStreamWrapper: response["choices"] = [text_choices] return response except Exception as e: - raise Exception(f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}") + raise Exception( + f"Error occurred converting to text completion object - chunk: {chunk}; Error: {str(e)}" + ) def __next__(self): # model_response = ModelResponse(stream=True, model=self.model) @@ -6183,32 +7158,34 @@ class TextCompletionStreamWrapper: try: for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopIteration - except Exception as e: - print(f"got exception {e}") # noqa + except Exception as e: + print(f"got exception {e}") # noqa async def __anext__(self): try: 
async for chunk in self.completion_stream: if chunk == "None" or chunk is None: - raise Exception - processed_chunk = self.convert_to_text_completion_object(chunk=chunk) + raise Exception + processed_chunk = self.convert_to_text_completion_object(chunk=chunk) return processed_chunk raise StopIteration except StopIteration: raise StopAsyncIteration + def mock_completion_streaming_obj(model_response, mock_response, model): for i in range(0, len(mock_response), 3): - completion_obj = {"role": "assistant", "content": mock_response[i: i+3]} + completion_obj = {"role": "assistant", "content": mock_response[i : i + 3]} model_response.choices[0].delta = completion_obj yield model_response + ########## Reading Config File ############################ def read_config_args(config_path) -> dict: try: @@ -6223,23 +7200,25 @@ def read_config_args(config_path) -> dict: except Exception as e: raise e + ########## experimental completion variants ############################ + def completion_with_config(config: Union[dict, str], **kwargs): """ - Generate a litellm.completion() using a config dict and all supported completion args + Generate a litellm.completion() using a config dict and all supported completion args Example config; config = { "default_fallback_models": # [Optional] List of model names to try if a call fails - "available_models": # [Optional] List of all possible models you could call + "available_models": # [Optional] List of all possible models you could call "adapt_to_prompt_size": # [Optional] True/False - if you want to select model based on prompt size (will pick from available_models) "model": { "model-name": { - "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. + "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged. 
"error_handling": { "error-type": { # One of the errors listed here - https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list - "fallback_model": "" # str, name of the model it should try instead, when that error occurs + "fallback_model": "" # str, name of the model it should try instead, when that error occurs } } } @@ -6263,11 +7242,11 @@ def completion_with_config(config: Union[dict, str], **kwargs): raise Exception("Config path must be a string or a dictionary.") else: raise Exception("Config path not passed in.") - + if config is None: raise Exception("No completion config in the config file") - - models_with_config = config["model"].keys() + + models_with_config = config["model"].keys() model = kwargs["model"] messages = kwargs["messages"] @@ -6278,13 +7257,16 @@ def completion_with_config(config: Union[dict, str], **kwargs): trim_messages_flag = config.get("trim_messages", False) prompt_larger_than_model = False max_model = model - try: + try: max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: - max_tokens = 2048 # assume curr model's max window is 2048 tokens + max_tokens = 2048 # assume curr model's max window is 2048 tokens if adapt_to_prompt_size: - ## Pick model based on token window - prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages)) + ## Pick model based on token window + prompt_tokens = litellm.token_counter( + model="gpt-3.5-turbo", + text="".join(message["content"] for message in messages), + ) try: curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"] except: @@ -6293,7 +7275,9 @@ def completion_with_config(config: Union[dict, str], **kwargs): prompt_larger_than_model = True for available_model in available_models: try: - curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"] + curr_max_tokens = litellm.get_max_tokens(available_model)[ + "max_tokens" + ] if curr_max_tokens > max_tokens: max_tokens = curr_max_tokens max_model = available_model @@ -6307,16 +7291,16 @@ def completion_with_config(config: Union[dict, str], **kwargs): kwargs["messages"] = messages kwargs["model"] = model - try: - if model in models_with_config: + try: + if model in models_with_config: ## Moderation check if config["model"][model].get("needs_moderation"): input = " ".join(message["content"] for message in messages) response = litellm.moderation(input=input) flagged = response["results"][0]["flagged"] - if flagged: + if flagged: raise Exception("This response was flagged as inappropriate") - + ## Model-specific Error Handling error_handling = None if config["model"][model].get("error_handling"): @@ -6328,22 +7312,25 @@ def completion_with_config(config: Union[dict, str], **kwargs): except Exception as e: exception_name = type(e).__name__ fallback_model = None - if error_handling and exception_name in error_handling: + if error_handling and exception_name in error_handling: error_handler = error_handling[exception_name] - # either switch model or api key + # either switch model or api key fallback_model = error_handler.get("fallback_model", None) - if fallback_model: + if fallback_model: kwargs["model"] = fallback_model return litellm.completion(**kwargs) raise e - else: + else: return litellm.completion(**kwargs) except Exception as e: if fallback_models: model = fallback_models.pop(0) - return completion_with_fallbacks(model=model, messages=messages, fallbacks=fallback_models) + return completion_with_fallbacks( + model=model, messages=messages, fallbacks=fallback_models + ) 
raise e + def completion_with_fallbacks(**kwargs): nested_kwargs = kwargs.pop("kwargs", {}) response = None @@ -6361,8 +7348,10 @@ def completion_with_fallbacks(**kwargs): for model in fallbacks: # loop thru all models try: - # check if it's dict or new model string - if isinstance(model, dict): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) + # check if it's dict or new model string + if isinstance( + model, dict + ): # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}]) kwargs["api_key"] = model.get("api_key", None) kwargs["api_base"] = model.get("api_base", None) model = model.get("model", original_model) @@ -6385,7 +7374,10 @@ def completion_with_fallbacks(**kwargs): print_verbose(f"trying to make completion call with model: {model}") kwargs["litellm_call_id"] = litellm_call_id - kwargs = {**kwargs, **nested_kwargs} # combine the openai + litellm params at the same level + kwargs = { + **kwargs, + **nested_kwargs, + } # combine the openai + litellm params at the same level response = litellm.completion(**kwargs, model=model) print_verbose(f"response: {response}") if response != None: @@ -6400,18 +7392,24 @@ def completion_with_fallbacks(**kwargs): pass return response + def process_system_message(system_message, max_tokens, model): system_message_event = {"role": "system", "content": system_message} system_message_tokens = get_token_count([system_message_event], model) if system_message_tokens > max_tokens: - print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...") + print_verbose( + "`tokentrimmer`: Warning, system message exceeds token limit. Trimming..." + ) # shorten system message to fit within max_tokens - new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model) + new_system_message = shorten_message_to_fit_limit( + system_message_event, max_tokens, model + ) system_message_tokens = get_token_count([new_system_message], model) - + return system_message_event, max_tokens - system_message_tokens + def process_messages(messages, max_tokens, model): # Process messages from older to more recent messages = messages[::-1] @@ -6422,17 +7420,26 @@ def process_messages(messages, max_tokens, model): available_tokens = max_tokens - used_tokens if available_tokens <= 3: break - final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model) + final_messages = attempt_message_addition( + final_messages=final_messages, + message=message, + available_tokens=available_tokens, + max_tokens=max_tokens, + model=model, + ) return final_messages -def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model): + +def attempt_message_addition( + final_messages, message, available_tokens, max_tokens, model +): temp_messages = [message] + final_messages temp_message_tokens = get_token_count(messages=temp_messages, model=model) if temp_message_tokens <= max_tokens: return temp_messages - + # if temp_message_tokens > max_tokens, try shortening temp_messages elif "function_call" not in message: # fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens) @@ -6442,19 +7449,18 @@ def attempt_message_addition(final_messages, message, available_tokens, max_toke return final_messages + def can_add_message(message, messages, max_tokens, model): if 
get_token_count(messages + [message], model) <= max_tokens: return True return False + def get_token_count(messages, model): return token_counter(model=model, messages=messages) -def shorten_message_to_fit_limit( - message, - tokens_needed, - model): +def shorten_message_to_fit_limit(message, tokens_needed, model): """ Shorten a message to fit within a token limit by removing characters from the middle. """ @@ -6462,7 +7468,7 @@ def shorten_message_to_fit_limit( # For OpenAI models, even blank messages cost 7 token, # and if the buffer is less than 3, the while loop will never end, # hence the value 10. - if 'gpt' in model and tokens_needed <= 10: + if "gpt" in model and tokens_needed <= 10: return message content = message["content"] @@ -6474,21 +7480,22 @@ def shorten_message_to_fit_limit( break ratio = (tokens_needed) / total_tokens - - new_length = int(len(content) * ratio) -1 + + new_length = int(len(content) * ratio) - 1 new_length = max(0, new_length) half_length = new_length // 2 left_half = content[:half_length] right_half = content[-half_length:] - trimmed_content = left_half + '..' + right_half + trimmed_content = left_half + ".." + right_half message["content"] = trimmed_content content = trimmed_content return message -# LiteLLM token trimmer + +# LiteLLM token trimmer # this code is borrowed from https://github.com/KillianLucas/tokentrim/blob/main/tokentrim/tokentrim.py # Credits for this code go to Killian Lucas def trim_messages( @@ -6496,8 +7503,8 @@ def trim_messages( model: Optional[str] = None, trim_ratio: float = 0.75, return_response_tokens: bool = False, - max_tokens = None - ): + max_tokens=None, +): """ Trim a list of messages to fit within a model's token limit. @@ -6519,18 +7526,18 @@ def trim_messages( if max_tokens == None: # Check if model is valid if model in litellm.model_cost: - max_tokens_for_model = litellm.model_cost[model]['max_tokens'] + max_tokens_for_model = litellm.model_cost[model]["max_tokens"] max_tokens = int(max_tokens_for_model * trim_ratio) else: - # if user did not specify max tokens + # if user did not specify max tokens # or passed an llm litellm does not know # do nothing, just return messages - return - - system_message = "" + return + + system_message = "" for message in messages: if message["role"] == "system": - system_message += '\n' if system_message else '' + system_message += "\n" if system_message else "" system_message += message["content"] current_tokens = token_counter(model=model, messages=messages) @@ -6538,38 +7545,47 @@ def trim_messages( # Do nothing if current tokens under messages if current_tokens < max_tokens: - return messages - - #### Trimming messages if current_tokens > max_tokens - print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}") - if system_message: - system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model) + return messages - if max_tokens == 0: # the system messages are too long + #### Trimming messages if current_tokens > max_tokens + print_verbose( + f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}" + ) + if system_message: + system_message_event, max_tokens = process_system_message( + system_message=system_message, max_tokens=max_tokens, model=model + ) + + if max_tokens == 0: # the system messages are too long return [system_message_event] - - # Since all system messages are combined and trimmed to fit the max_tokens, + + # 
Since all system messages are combined and trimmed to fit the max_tokens, # we remove all system messages from the messages list messages = [message for message in messages if message["role"] != "system"] - final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model) + final_messages = process_messages( + messages=messages, max_tokens=max_tokens, model=model + ) # Add system message to the beginning of the final messages if system_message: final_messages = [system_message_event] + final_messages - if return_response_tokens: # if user wants token count with new trimmed messages + if ( + return_response_tokens + ): # if user wants token count with new trimmed messages response_tokens = max_tokens - get_token_count(final_messages, model) return final_messages, response_tokens return final_messages - except Exception as e: # [NON-Blocking, if error occurs just return final_messages + except Exception as e: # [NON-Blocking, if error occurs just return final_messages print_verbose(f"Got exception while token trimming{e}") return messages + def get_valid_models(): """ Returns a list of valid LLMs based on the set environment variables - + Args: None @@ -6587,13 +7603,13 @@ def get_valid_models(): # edge case litellm has together_ai as a provider, it should be togetherai provider = provider.replace("_", "") - # litellm standardizes expected provider keys to + # litellm standardizes expected provider keys to # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY expected_provider_key = f"{provider.upper()}_API_KEY" - if expected_provider_key in environ_keys: - # key is set + if expected_provider_key in environ_keys: + # key is set valid_providers.append(provider) - + for provider in valid_providers: if provider == "azure": valid_models.append("Azure-LLM") @@ -6602,7 +7618,8 @@ def get_valid_models(): valid_models.extend(models_for_provider) return valid_models except: - return [] # NON-Blocking + return [] # NON-Blocking + # used for litellm.text_completion() to transform HF logprobs to OpenAI.Completion() format def transform_logprobs(hf_response): @@ -6612,40 +7629,39 @@ def transform_logprobs(hf_response): # For each Hugging Face response, transform the logprobs for response in hf_response: # Extract the relevant information from the response - response_details = response['details'] + response_details = response["details"] top_tokens = response_details.get("top_tokens", {}) # Initialize an empty list for the token information token_info = { - 'tokens': [], - 'token_logprobs': [], - 'text_offset': [], - 'top_logprobs': [], + "tokens": [], + "token_logprobs": [], + "text_offset": [], + "top_logprobs": [], } - for i, token in enumerate(response_details['prefill']): + for i, token in enumerate(response_details["prefill"]): # Extract the text of the token - token_text = token['text'] + token_text = token["text"] # Extract the logprob of the token - token_logprob = token['logprob'] + token_logprob = token["logprob"] # Add the token information to the 'token_info' list - token_info['tokens'].append(token_text) - token_info['token_logprobs'].append(token_logprob) + token_info["tokens"].append(token_text) + token_info["token_logprobs"].append(token_logprob) # stub this to work with llm eval harness - top_alt_tokens = { "": -1, "": -2, "": -3 } - token_info['top_logprobs'].append(top_alt_tokens) + top_alt_tokens = {"": -1, "": -2, "": -3} + token_info["top_logprobs"].append(top_alt_tokens) # For each element in the 'tokens' list, extract the relevant information - for i, 
token in enumerate(response_details['tokens']): - + for i, token in enumerate(response_details["tokens"]): # Extract the text of the token - token_text = token['text'] + token_text = token["text"] # Extract the logprob of the token - token_logprob = token['logprob'] + token_logprob = token["logprob"] top_alt_tokens = {} temp_top_logprobs = [] @@ -6659,13 +7675,15 @@ def transform_logprobs(hf_response): top_alt_tokens[text] = logprob # Add the token information to the 'token_info' list - token_info['tokens'].append(token_text) - token_info['token_logprobs'].append(token_logprob) - token_info['top_logprobs'].append(top_alt_tokens) + token_info["tokens"].append(token_text) + token_info["token_logprobs"].append(token_logprob) + token_info["top_logprobs"].append(top_alt_tokens) # Add the text offset of the token # This is computed as the sum of the lengths of all previous tokens - token_info['text_offset'].append(sum(len(t['text']) for t in response_details['tokens'][:i])) + token_info["text_offset"].append( + sum(len(t["text"]) for t in response_details["tokens"][:i]) + ) # Add the 'token_info' list to the 'transformed_logprobs' list transformed_logprobs = token_info
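
Two of the helpers touched by the reformatting hunks above carry non-obvious logic that is easy to miss in a whitespace-heavy diff. The sketches below are illustrative only; they restate the technique in standalone form and do not reproduce the library's actual functions (names, the 4-chars-per-token estimate, and the sample inputs are assumptions for the example).

First, shorten_message_to_fit_limit() trims an over-long message by keeping its head and tail and dropping the middle, scaling the kept length by the ratio of the token budget to the current token count. A minimal sketch of that middle-truncation idea, using a crude character-based token estimate in place of the real token_counter():

    def truncate_middle(content: str, tokens_needed: int, chars_per_token: int = 4) -> str:
        # Rough token estimate; the real helper counts tokens with the model's tokenizer.
        total_tokens = max(1, len(content) // chars_per_token)
        if total_tokens <= tokens_needed:
            return content
        ratio = tokens_needed / total_tokens
        new_length = max(0, int(len(content) * ratio) - 1)
        half = new_length // 2
        if half == 0:
            return ".."
        # Keep the first and last `half` characters, mark the removed middle with "..".
        return content[:half] + ".." + content[-half:]

    # e.g. a 400-character message trimmed to a ~10-token budget keeps ~40 characters.
    print(truncate_middle("a" * 400, tokens_needed=10))

Second, transform_logprobs() builds an OpenAI-style text_offset list where each token's offset is the combined length of all tokens before it (the hunk computes it as a sum over the preceding slice). A self-contained sketch of the same bookkeeping with a running total, on illustrative token data:

    tokens = [{"text": "Hello"}, {"text": ","}, {"text": " world"}]

    text_offset = []
    running = 0
    for tok in tokens:
        text_offset.append(running)      # offset of this token = length of everything before it
        running += len(tok["text"])

    assert text_offset == [0, 5, 6]
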