From f4a7760ea14915f5e0ff170d5d1cef81cfd24f7a Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 28 Nov 2023 16:09:10 -0800
Subject: [PATCH] (feat+test) use passed OpenAI client

---
 litellm/llms/openai.py           | 47 +++++++++++++++++++++++---------
 litellm/tests/test_completion.py | 29 +++++++++++++++++---
 2 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 6d0809442c..ff3d7f0e4e 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -174,7 +174,9 @@ class OpenAIChatCompletion(BaseLLM):
                 litellm_params=None,
                 logger_fn=None,
                 headers: Optional[dict]=None,
-                custom_prompt_dict: dict={}):
+                custom_prompt_dict: dict={},
+                client=None
+                ):
         super().completion()
         exception_mapping_worked = False
         try:
@@ -203,16 +205,19 @@ class OpenAIChatCompletion(BaseLLM):
             try:
                 if acompletion is True:
                     if optional_params.get("stream", False):
-                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
+                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
                     else:
-                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
+                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
                 elif optional_params.get("stream", False):
-                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
+                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client)
                 else:
                     max_retries = data.pop("max_retries", 2)
                     if not isinstance(max_retries, int):
                         raise OpenAIError(status_code=422, message="max retries must be an int")
-                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+                    if client is None:
+                        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+                    else:
+                        openai_client = client
                     response = openai_client.chat.completions.create(**data) # type: ignore
                     logging_obj.post_call(
                             input=None,
@@ -251,10 +256,15 @@ class OpenAIChatCompletion(BaseLLM):
                           model_response: ModelResponse,
                           timeout: float,
                           api_key: Optional[str]=None,
-                          api_base: Optional[str]=None):
+                          api_base: Optional[str]=None,
+                          client=None
+                          ):
         response = None
         try:
-            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            if client is None:
+                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            else:
+                openai_aclient = client
             response = await openai_aclient.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
@@ -272,9 +282,13 @@ class OpenAIChatCompletion(BaseLLM):
                   data: dict,
                   model: str,
                   api_key: Optional[str]=None,
-                  api_base: Optional[str]=None
+                  api_base: Optional[str]=None,
+                  client = None,
                   ):
-        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+        if client is None:
+            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+        else:
+            openai_client = client
         response = openai_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
@@ -286,10 +300,14 @@ class OpenAIChatCompletion(BaseLLM):
                         data: dict,
                         model: str,
                         api_key: Optional[str]=None,
-                        api_base: Optional[str]=None):
+                        api_base: Optional[str]=None,
+                        client=None):
         response = None
         try:
-            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            if client is None:
+                openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=data.pop("max_retries", 2))
+            else:
+                openai_aclient = client
             response = await openai_aclient.chat.completions.create(**data)
             streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
             async for transformed_chunk in streamwrapper:
@@ -312,6 +330,7 @@ class OpenAIChatCompletion(BaseLLM):
                   model_response: Optional[litellm.utils.EmbeddingResponse] = None,
                   logging_obj=None,
                   optional_params=None,
+                  client=None,
                   ):
         super().embedding()
         exception_mapping_worked = False
@@ -325,8 +344,10 @@ class OpenAIChatCompletion(BaseLLM):
             max_retries = data.pop("max_retries", 2)
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
-            openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, max_retries=max_retries, timeout=timeout)
-
+            if client is None:
+                openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+            else:
+                openai_client = client
             ## LOGGING
             logging_obj.pre_call(
                     input=input,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 30bedbe7fb..ec4abf113b 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -606,10 +606,31 @@ async def test_re_use_azure_async_client():
     except Exception as e:
         pytest.fail("got Exception", e)
 
-import asyncio
-asyncio.run(
-    test_re_use_azure_async_client()
-)
+# import asyncio
+# asyncio.run(
+#     test_re_use_azure_async_client()
+# )
+
+
+def test_re_use_openaiClient():
+    try:
+        print("gpt-3.5 with client test\n\n")
+        litellm.set_verbose = True
+        import openai
+        client = openai.OpenAI(
+            api_key=os.environ["OPENAI_API_KEY"],
+        )
+        ## Test OpenAI call
+        for _ in range(2):
+            response = litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=messages,
+                client=client
+            )
+            print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"got Exception: {e}")
+test_re_use_openaiClient()
 
 def test_completion_azure():
     try:
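
Usage sketch for reviewers: the snippet below illustrates what this patch
enables. One pre-configured openai.OpenAI client is passed through the new
`client` parameter of litellm.completion, so client-level settings (API key,
timeout, retries) and the underlying HTTP connection pool are reused across
calls instead of a fresh client being constructed per request. This sketch is
illustrative and not part of the patch; it assumes a litellm build that
includes this change, and the model name and message payload are placeholders.

    import os

    import litellm
    import openai

    # Client-level configuration travels with the client object; with this
    # patch, litellm uses the passed client as-is (the `if client is None`
    # branches above) instead of constructing a new OpenAI client per call.
    client = openai.OpenAI(
        api_key=os.environ["OPENAI_API_KEY"],
        timeout=30,
        max_retries=5,
    )

    for _ in range(2):
        # The same client instance is reused on every call, mirroring the
        # test_re_use_openaiClient test added in this patch.
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            client=client,
        )
        print(response)

The async and streaming paths accept the same `client` argument: it is
threaded through the acompletion, streaming, and async_streaming changes
above, so an openai.AsyncOpenAI instance can be reused the same way.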