From 8609694b4963352618baf795795d9e0fb30692d8 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 28 Nov 2023 17:09:58 -0800
Subject: [PATCH] (fix) completion:openai-pop out max_retries from completion kwargs

---
 litellm/llms/openai.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index ff3d7f0e4e..41570f0b5c 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -203,15 +203,15 @@ class OpenAIChatCompletion(BaseLLM):
             )
 
         try:
+            max_retries = data.pop("max_retries", 2)
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
+                    return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
                 else:
-                    return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
+                    return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
             elif optional_params.get("stream", False):
-                return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
+                return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
             else:
-                max_retries = data.pop("max_retries", 2)
                 if not isinstance(max_retries, int):
                     raise OpenAIError(status_code=422, message="max retries must be an int")
                 if client is None:
@@ -257,12 +257,13 @@ class OpenAIChatCompletion(BaseLLM):
                           timeout: float,
                           api_key: Optional[str]=None,
                           api_base: Optional[str]=None,
-                          client=None
+                          client=None,
+                          max_retries=None,
                          ):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
             else:
                 openai_aclient = client
             response = await openai_aclient.chat.completions.create(**data)
@@ -284,9 +285,10 @@ class OpenAIChatCompletion(BaseLLM):
                   api_key: Optional[str]=None,
                   api_base: Optional[str]=None,
                   client = None,
+                  max_retries=None
                  ):
         if client is None:
-            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
         else:
             openai_client = client
         response = openai_client.chat.completions.create(**data)
@@ -301,11 +303,13 @@ class OpenAIChatCompletion(BaseLLM):
                         model: str,
                         api_key: Optional[str]=None,
                         api_base: Optional[str]=None,
-                        client=None):
+                        client=None,
+                        max_retries=None,
+                        ):
         response = None
         try:
             if client is None:
-                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
             else:
                 openai_aclient = client
             response = await openai_aclient.chat.completions.create(**data)
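
Note on the fix: before this patch, max_retries was only popped from the request kwargs
when a default client was constructed; on the async and streaming paths with a
caller-supplied client, the pop never ran, so chat.completions.create(**data) could
receive an unexpected max_retries keyword. The patch pops it once, up front, and
threads it through explicitly. A minimal sketch of the pattern (hypothetical
create_chat helper, not litellm's API; the OpenAI v1 SDK does accept max_retries on
the client constructor, not on create()):

    from typing import Optional
    from openai import OpenAI

    def create_chat(data: dict, api_key: str, client: Optional[OpenAI] = None):
        # Pop once, up front, so every path below sees clean request kwargs.
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise ValueError("max retries must be an int")
        if client is None:
            # max_retries is a client-level setting in the OpenAI v1 SDK...
            client = OpenAI(api_key=api_key, max_retries=max_retries)
        # ...so it must not leak into the per-request kwargs, where create()
        # would raise TypeError on an unexpected keyword argument.
        return client.chat.completions.create(**data)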