fix(main.py): keep client consistent across calls + exponential backoff retry on rate-limit errors

Krrish Dholakia 2023-11-14 16:25:36 -08:00
parent 5963d9d283
commit a7222f257c
9 changed files with 239 additions and 131 deletions
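
Only the client-session half of the title shows up in the hunks below; the exponential-backoff retry is not part of the extracted diff. As a rough sketch of that pattern only (post_with_backoff, max_retries, and base_delay are illustrative names, not the helper this commit actually adds), retrying on HTTP 429 with doubling delays looks like:

    import asyncio
    import random

    import httpx

    async def post_with_backoff(client: httpx.AsyncClient, url: str, *,
                                max_retries: int = 5, base_delay: float = 1.0,
                                **kwargs) -> httpx.Response:
        # Retry only on HTTP 429 (rate limit); any other status is handed
        # back to the caller immediately.
        response = await client.post(url, **kwargs)
        for attempt in range(max_retries):
            if response.status_code != 429:
                break
            # Wait base_delay * 2**attempt, plus jitter so concurrent
            # callers don't retry in lockstep.
            await asyncio.sleep(base_delay * (2 ** attempt) + random.uniform(0, 1))
            response = await client.post(url, **kwargs)
        return response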


@@ -154,10 +154,12 @@ class OpenAITextCompletionConfig():
 class OpenAIChatCompletion(BaseLLM):
     _client_session: httpx.Client
+    _aclient_session: httpx.AsyncClient
 
     def __init__(self) -> None:
         super().__init__()
         self._client_session = self.create_client_session()
+        self._aclient_session = self.create_aclient_session()
 
     def validate_environment(self, api_key):
         headers = {
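
create_aclient_session's body is not part of this diff. Assuming it mirrors create_client_session and keeps the 600-second timeout the old per-request client used (visible in the next hunk), a minimal sketch would be:

    import httpx

    def create_aclient_session(self) -> httpx.AsyncClient:
        # Built once in __init__ so every async request from this instance
        # shares one connection pool instead of opening sockets per call.
        return httpx.AsyncClient(timeout=httpx.Timeout(600.0))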
@@ -251,15 +253,15 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient(timeout=600) as client:
-            response = await client.post(api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+        client = self._aclient_session
+        response = await client.post(api_base, json=data, headers=headers)
+        response_json = response.json()
+        if response.status_code != 200:
+            raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+        ## RESPONSE OBJECT
+        return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
 
     def streaming(self,
                   logging_obj,
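
This hunk is the core of the fix: the old path built a fresh AsyncClient (new connection pool, new TLS handshakes) for every acompletion call and tore it down immediately, while the new path posts through the long-lived session. The trade-off, a property of httpx rather than anything shown here, is that nothing closes the shared client anymore, so teardown has to happen explicitly somewhere:

    # Hypothetical cleanup at shutdown; this commit's actual teardown,
    # if any, isn't in the extracted hunks.
    await openai_chat_completion._aclient_session.aclose()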
@@ -290,8 +292,7 @@ class OpenAIChatCompletion(BaseLLM):
                          headers: dict,
                          model_response: ModelResponse,
                          model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
+        async with self._aclient_session.stream(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
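
Note that the streamed path still uses an async with, but httpx's AsyncClient.stream() returns a context manager over the response, not the client: leaving the block closes the streamed response while the shared session and its pooled connections stay open for later calls.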