fix(azure.py): support health checks to text completion endpoints

This commit is contained in:
Krrish Dholakia 2024-01-12 00:13:01 +05:30
parent 0a76269541
commit a7f182b8ec
3 changed files with 20 additions and 10 deletions

View file

@ -753,6 +753,11 @@ class AzureChatCompletion(BaseLLM):
completion = None completion = None
if mode == "completion": if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None: if messages is None:
raise Exception("messages is not set") raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create( completion = await client.chat.completions.with_raw_response.create(

View file

@ -741,6 +741,11 @@ class OpenAIChatCompletion(BaseLLM):
completion = None completion = None
if mode == "completion": if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None: if messages is None:
raise Exception("messages is not set") raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create( completion = await client.chat.completions.with_raw_response.create(
@ -889,7 +894,7 @@ class OpenAITextCompletion(BaseLLM):
headers=headers, headers=headers,
model_response=model_response, model_response=model_response,
model=model, model=model,
timeout=timeout timeout=timeout,
) )
else: else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout) # type: ignore
@ -901,14 +906,11 @@ class OpenAITextCompletion(BaseLLM):
headers=headers, headers=headers,
model_response=model_response, model_response=model_response,
model=model, model=model,
timeout=timeout timeout=timeout,
) )
else: else:
response = httpx.post( response = httpx.post(
url=f"{api_base}", url=f"{api_base}", json=data, headers=headers, timeout=timeout
json=data,
headers=headers,
timeout=timeout
) )
if response.status_code != 200: if response.status_code != 200:
raise OpenAIError( raise OpenAIError(
@ -944,7 +946,7 @@ class OpenAITextCompletion(BaseLLM):
prompt: str, prompt: str,
api_key: str, api_key: str,
model: str, model: str,
timeout: float timeout: float,
): ):
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
try: try:
@ -986,7 +988,7 @@ class OpenAITextCompletion(BaseLLM):
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str, model: str,
timeout: float timeout: float,
): ):
with httpx.stream( with httpx.stream(
url=f"{api_base}", url=f"{api_base}",
@ -1017,7 +1019,7 @@ class OpenAITextCompletion(BaseLLM):
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str, model: str,
timeout: float timeout: float,
): ):
client = httpx.AsyncClient() client = httpx.AsyncClient()
async with client.stream( async with client.stream(

View file

@ -3123,7 +3123,10 @@ async def ahealth_check(
prompt=prompt, prompt=prompt,
input=input, input=input,
) )
elif custom_llm_provider == "openai": elif (
custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
timeout = ( timeout = (