(Refactor) - Re use litellm.completion/litellm.embedding etc for health checks (#7455)

* add mode: realtime

* add _realtime_health_check

* test_realtime_health_check

* azure _realtime_health_check

* _realtime_health_check

* Realtime Models

* fix code quality

* delete OAI / Azure custom health check code

* simplest version of ahealth check

* update tests

* working health check post refactor

* working aspeech health check

* fix realtime health checks

* test_audio_transcription_health_check

* use get_audio_file_for_health_check

* test_text_completion_health_check

* ahealth_check

* simplify health check code

* update ahealth_check

* fix import

* fix unused imports

* fix ahealth_check

* fix local testing

* test_async_realtime_health_check
This commit is contained in:
Ishaan Jaff 2024-12-28 18:38:54 -08:00 committed by GitHub
parent 4e65722a00
commit 1e06ee3162
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 188 additions and 373 deletions

View file

@ -1491,132 +1491,3 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
async def ahealth_check(
self,
model: Optional[str],
api_key: Optional[str],
api_base: str,
api_version: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
) -> dict:
client_session = (
litellm.aclient_session
or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
) # handle dall-e-2 calls
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
client = AsyncAzureOpenAI(
base_url=api_base,
api_version=api_version,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
model = None
# cloudflare ai gateway, needs model=None
else:
client = AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
# only run this check if it's not cloudflare ai gateway
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if mode == "completion":
completion = await client.completions.with_raw_response.create(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "chat":
if messages is None:
raise Exception("messages is not set")
completion = await client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
elif mode == "embedding":
if input is None:
raise Exception("input is not set")
completion = await client.embeddings.with_raw_response.create(
model=model, # type: ignore
input=input, # type: ignore
)
elif mode == "image_generation":
if prompt is None:
raise Exception("prompt is not set")
completion = await client.images.with_raw_response.generate(
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_transcription":
# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(
pwd, "../../../tests/gettysburg.wav"
) # proxy address
audio_file = open(file_path, "rb")
completion = await client.audio.transcriptions.with_raw_response.create(
file=audio_file,
model=model, # type: ignore
prompt=prompt, # type: ignore
)
elif mode == "audio_speech":
# Get the current directory of the file being run
completion = await client.audio.speech.with_raw_response.create(
model=model, # type: ignore
input=prompt, # type: ignore
voice="alloy",
)
elif mode == "batch":
completion = await client.batches.with_raw_response.list(limit=1) # type: ignore
elif mode == "realtime":
from litellm.realtime_api.main import _realtime_health_check
# create a websocket connection
await _realtime_health_check(
model=model or "",
api_key=api_key,
api_base=api_base,
api_version=api_version,
custom_llm_provider="azure",
)
return {}
else:
raise Exception("mode not set")
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response