(feat+test) use passed OpenAI client

This commit is contained in:
ishaan-jaff 2023-11-28 16:09:10 -08:00
parent 8ac7801283
commit f4a7760ea1
2 changed files with 59 additions and 17 deletions

View file

@ -174,7 +174,9 @@ class OpenAIChatCompletion(BaseLLM):
litellm_params=None,
logger_fn=None,
headers: Optional[dict]=None,
custom_prompt_dict: dict={}):
custom_prompt_dict: dict={},
client=None
):
super().completion()
exception_mapping_worked = False
try:
@ -203,16 +205,19 @@ class OpenAIChatCompletion(BaseLLM):
try:
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
else:
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
else:
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore
logging_obj.post_call(
input=None,
@ -251,10 +256,15 @@ class OpenAIChatCompletion(BaseLLM):
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None):
api_base: Optional[str]=None,
client=None
):
response = None
try:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
else:
openai_aclient = client
response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
@ -272,9 +282,13 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None
api_base: Optional[str]=None,
client = None,
):
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
else:
openai_client = client
response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
@ -286,10 +300,14 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None):
api_base: Optional[str]=None,
client=None):
response = None
try:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
if client is None:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
else:
openai_aclient = client
response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
@ -312,6 +330,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
):
super().embedding()
exception_mapping_worked = False
@ -325,8 +344,10 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=max_retries, timeout=timeout)
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=input,