From a7f182b8ecb6e4456d8e69eae1478b67c43e7e0e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 12 Jan 2024 00:13:01 +0530 Subject: [PATCH] fix(azure.py): support health checks to text completion endpoints --- litellm/llms/azure.py | 5 +++++ litellm/llms/openai.py | 20 +++++++++++--------- litellm/main.py | 5 ++++- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 716b65dbb..c2e1e510b 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -753,6 +753,11 @@ class AzureChatCompletion(BaseLLM): completion = None if mode == "completion": + completion = await client.completions.with_raw_response.create( + model=model, # type: ignore + prompt=prompt, # type: ignore + ) + elif mode == "chat": if messages is None: raise Exception("messages is not set") completion = await client.chat.completions.with_raw_response.create( diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 0ae5633d5..9285bf6f5 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -741,6 +741,11 @@ class OpenAIChatCompletion(BaseLLM): completion = None if mode == "completion": + completion = await client.completions.with_raw_response.create( + model=model, # type: ignore + prompt=prompt, # type: ignore + ) + elif mode == "chat": if messages is None: raise Exception("messages is not set") completion = await client.chat.completions.with_raw_response.create( @@ -889,7 +894,7 @@ class OpenAITextCompletion(BaseLLM): headers=headers, model_response=model_response, model=model, - timeout=timeout + timeout=timeout, ) else: return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore @@ -901,14 +906,11 @@ class OpenAITextCompletion(BaseLLM): headers=headers, model_response=model_response, model=model, - timeout=timeout + timeout=timeout, ) else: response = httpx.post( - url=f"{api_base}", - json=data, - headers=headers, - timeout=timeout + url=f"{api_base}", json=data, headers=headers, timeout=timeout ) if response.status_code != 200: raise OpenAIError( @@ -944,7 +946,7 @@ class OpenAITextCompletion(BaseLLM): prompt: str, api_key: str, model: str, - timeout: float + timeout: float, ): async with httpx.AsyncClient(timeout=timeout) as client: try: @@ -986,7 +988,7 @@ class OpenAITextCompletion(BaseLLM): headers: dict, model_response: ModelResponse, model: str, - timeout: float + timeout: float, ): with httpx.stream( url=f"{api_base}", @@ -1017,7 +1019,7 @@ class OpenAITextCompletion(BaseLLM): headers: dict, model_response: ModelResponse, model: str, - timeout: float + timeout: float, ): client = httpx.AsyncClient() async with client.stream( diff --git a/litellm/main.py b/litellm/main.py index 70264b312..e9b0f9955 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3123,7 +3123,10 @@ async def ahealth_check( prompt=prompt, input=input, ) - elif custom_llm_provider == "openai": + elif ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + ): api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") timeout = (