(Refactor) - Reuse litellm.completion/litellm.embedding etc. for health checks (#7455)
* add mode: realtime
* add _realtime_health_check
* test_realtime_health_check
* azure _realtime_health_check
* _realtime_health_check
* Realtime Models
* fix code quality
* delete OAI / Azure custom health check code
* simplest version of ahealth check
* update tests
* working health check post refactor
* working aspeech health check
* fix realtime health checks
* test_audio_transcription_health_check
* use get_audio_file_for_health_check
* test_text_completion_health_check
* ahealth_check
* simplify health check code
* update ahealth_check
* fix import
* fix unused imports
* fix ahealth_check
* fix local testing
* test_async_realtime_health_check
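The thrust of the refactor, per the message above, is to delete the hand-rolled OpenAI/Azure health-check clients and drive each health-check mode through litellm's own async entry points. Below is a minimal sketch of that dispatch idea, assuming nothing beyond litellm's public `acompletion`/`atext_completion`/`aembedding`/`aimage_generation`/`aspeech` calls; the helper name, probe payloads, and return shape are illustrative, not the actual post-refactor `ahealth_check`:

```python
# Minimal sketch of the refactor's idea: route each health-check mode through
# litellm's own async entry points instead of per-provider client code.
# Helper name, probe payloads, and return shape are illustrative assumptions;
# the real post-refactor ahealth_check handles more modes (realtime, batch,
# audio_transcription via a bundled sample file) and richer error reporting.
import litellm


async def simple_ahealth_check(model: str, mode: str, **params) -> dict:
    if mode == "chat":
        await litellm.acompletion(
            model=model, messages=[{"role": "user", "content": "ping"}], **params
        )
    elif mode == "completion":
        await litellm.atext_completion(model=model, prompt="ping", **params)
    elif mode == "embedding":
        await litellm.aembedding(model=model, input=["ping"], **params)
    elif mode == "image_generation":
        await litellm.aimage_generation(model=model, prompt="ping", **params)
    elif mode == "audio_speech":
        await litellm.aspeech(model=model, input="ping", voice="alloy", **params)
    else:
        raise ValueError(f"unsupported health check mode: {mode}")
    # if the provider call did not raise, the deployment is reachable
    return {"status": "healthy", "mode": mode}
```

Any provider that those entry points already support then gets health checks for free, which is what lets this commit delete the Azure-specific implementation shown in the diff below.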
Parent: 4e65722a00
Commit: 1e06ee3162
9 changed files with 188 additions and 373 deletions
@@ -1491,132 +1491,3 @@ class AzureChatCompletion(BaseLLM):

```python
            response["x-ms-region"] = completion.headers["x-ms-region"]

        return response

    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: Optional[str],
        api_base: str,
        api_version: Optional[str],
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ) -> dict:
        client_session = (
            litellm.aclient_session
            or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
        )  # handle dall-e-2 calls

        if "gateway.ai.cloudflare.com" in api_base:
            ## build base url - assume api base includes resource name
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"
            client = AsyncAzureOpenAI(
                base_url=api_base,
                api_version=api_version,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )
            model = None
            # cloudflare ai gateway, needs model=None
        else:
            client = AsyncAzureOpenAI(
                api_version=api_version,
                azure_endpoint=api_base,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )

        # only run this check if it's not cloudflare ai gateway
        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Get the current directory of the file being run
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(
                pwd, "../../../tests/gettysburg.wav"
            )  # proxy address
            audio_file = open(file_path, "rb")
            completion = await client.audio.transcriptions.with_raw_response.create(
                file=audio_file,
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_speech":
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        elif mode == "batch":
            completion = await client.batches.with_raw_response.list(limit=1)  # type: ignore
        elif mode == "realtime":
            from litellm.realtime_api.main import _realtime_health_check

            # create a websocket connection
            await _realtime_health_check(
                model=model or "",
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                custom_llm_provider="azure",
            )
            return {}
        else:
            raise Exception("mode not set")

        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]

        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]

        if completion.headers.get("x-ms-region", None) is not None:
            response["x-ms-region"] = completion.headers["x-ms-region"]

        return response
```
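The new `mode: realtime` branch is the one piece that cannot reuse a standard completion call: it delegates to `_realtime_health_check`, which opens a websocket connection. A hypothetical sketch of what such a check amounts to, assuming OpenAI's realtime endpoint shape and the third-party `websockets` package (the actual helper in `litellm.realtime_api.main` builds provider-specific URLs for OpenAI and Azure):

```python
# Hypothetical sketch of a realtime health check: open a websocket to the
# provider's realtime endpoint and treat a successful handshake as healthy.
# URL shape and headers are assumptions modeled on OpenAI's realtime API;
# the actual _realtime_health_check in litellm.realtime_api.main may differ.
import websockets


async def realtime_health_check(model: str, api_key: str, api_base: str) -> None:
    url = api_base.replace("https://", "wss://") + f"/v1/realtime?model={model}"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "OpenAI-Beta": "realtime=v1",
    }
    # keyword is `additional_headers` on websockets >= 14; older releases
    # call it `extra_headers`
    async with websockets.connect(url, additional_headers=headers):
        pass  # handshake succeeded; the connection closes on exit
```

Raising on a failed handshake, rather than returning a bool, matches how the removed code consumes the helper: the exception propagates out of `ahealth_check` and the deployment is marked unhealthy.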