diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 4ebc6cd17..d3efb60e7 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -134,8 +134,6 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
             message=message
         )
 
-
-
 class OpenAIError(OpenAIError): # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index d4386e0bb..f6c8ae9a4 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -136,7 +136,10 @@ class AzureChatCompletion(BaseLLM):
             elif "stream" in optional_params and optional_params["stream"] == True:
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
             else:
-                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                max_retries = data.pop("max_retries", 2)
+                if not isinstance(max_retries, int):
+                    raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
                 response = azure_client.chat.completions.create(**data) # type: ignore
                 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
@@ -156,7 +159,10 @@ class AzureChatCompletion(BaseLLM):
                           azure_ad_token: Optional[str]=None, ):
        response = None
        try:
-            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
            response = await azure_client.chat.completions.create(**data)
            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
        except Exception as e:
@@ -177,7 +183,10 @@ class AzureChatCompletion(BaseLLM):
                   timeout: Any,
                   azure_ad_token: Optional[str]=None,
     ):
-        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+        max_retries = data.pop("max_retries", 2)
+        if not isinstance(max_retries, int):
+            raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
         response = azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
@@ -218,7 +227,10 @@ class AzureChatCompletion(BaseLLM):
                 "input": input,
                 **optional_params
             }
-            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, max_retries=max_retries)
 
             ## LOGGING
             logging_obj.pre_call(
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 4a164bbe2..1a680744a 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -209,7 +209,10 @@ class OpenAIChatCompletion(BaseLLM):
                 elif optional_params.get("stream", False):
                     return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                 else:
-                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                    max_retries = data.pop("max_retries", 2)
+                    if not isinstance(max_retries, int):
+                        raise OpenAIError(status_code=422, message="max retries must be an int")
+                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
                     response = openai_client.chat.completions.create(**data) # type: ignore
                     logging_obj.post_call(
                             input=None,
@@ -317,7 +320,10 @@ class OpenAIChatCompletion(BaseLLM):
                 "input": input,
                 **optional_params
             }
-            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=data.pop("max_retries", 2))
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise OpenAIError(status_code=422, message="max retries must be an int")
+            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=max_retries)
 
             ## LOGGING
             logging_obj.pre_call(
diff --git a/litellm/utils.py b/litellm/utils.py
index 0925c4b20..b47d9554e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1082,7 +1082,7 @@ class Rules:
             if callable(rule):
                 decision = rule(input)
                 if decision is False:
-                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+                    raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model) # type: ignore
         return True
 
     def post_call_rules(self, input: str, model: str):
@@ -1091,7 +1091,7 @@ class Rules:
             if callable(rule):
                 decision = rule(input)
                 if decision is False:
-                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+                    raise litellm.APIResponseValidationError(message="LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model) # type: ignore
         return True
 
 ####### CLIENT ###################
@@ -3135,10 +3135,10 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
                 choice_list.append(choice)
             model_response_object.choices = choice_list
 
-            if "usage" in response_object:
-                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0)
-                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)
-                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)
+            if "usage" in response_object and response_object["usage"] is not None:
+                model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
+                model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
+                model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
 
             if "id" in response_object:
                 model_response_object.id = response_object["id"]