fix(router.py): support retry and fallbacks for atext_completion

Krrish Dholakia 2023-12-30 11:19:13 +05:30
parent 7ecd7b3e8d
commit 38f55249e1
6 changed files with 290 additions and 69 deletions

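What the commit enables, as a minimal sketch (the deployment names, keys, and ids below are illustrative placeholders, not taken from this diff): a Router configured with num_retries and fallbacks should now apply both to atext_completion calls, as it already does for the chat-completion entry points.

import asyncio
from litellm import Router

# Hypothetical model list for illustration; names, keys, and ids are placeholders.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-primary-key"},
        "model_info": {"id": "deployment-1"},
    },
    {
        "model_name": "gpt-3.5-turbo-backup",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-backup-key"},
        "model_info": {"id": "deployment-2"},
    },
]

router = Router(
    model_list=model_list,
    num_retries=2,  # retry the same group before falling back
    fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-backup"]}],
)

async def main():
    # With this commit, a failure on the primary deployment is retried and
    # then routed to the fallback group instead of raising immediately.
    response = await router.atext_completion(
        model="gpt-3.5-turbo", prompt="Hello, world"
    )
    print(response)

asyncio.run(main())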

@@ -479,6 +479,8 @@ class EmbeddingResponse(OpenAIObject):
    usage: Optional[Usage] = None
    """Usage statistics for the embedding request."""

    _hidden_params: dict = {}

    def __init__(
        self, model=None, usage=None, stream=False, response_ms=None, data=None
    ):
@@ -640,6 +642,8 @@ class ImageResponse(OpenAIObject):
    usage: Optional[dict] = None

    _hidden_params: dict = {}

    def __init__(self, created=None, data=None, response_ms=None):
        if response_ms:
            _response_ms = response_ms
@@ -2053,6 +2057,10 @@ def client(original_function):
                target=logging_obj.success_handler, args=(result, start_time, end_time)
            ).start()
            # RETURN RESULT
            if hasattr(result, "_hidden_params"):
                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                    "id", None
                )
            result._response_ms = (
                end_time - start_time
            ).total_seconds() * 1000  # return response latency in ms like openai
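Callers can read the stamped id back off the result; a small sketch, reusing the response from the example above (the hasattr guard mirrors the one in the diff, since not every response type carries _hidden_params):

if hasattr(response, "_hidden_params"):
    # "model_id" is the id from the matching model_info entry, or None.
    print(response._hidden_params.get("model_id"))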
@@ -2273,6 +2281,10 @@ def client(original_function):
                target=logging_obj.success_handler, args=(result, start_time, end_time)
            ).start()
            # RETURN RESULT
            if hasattr(result, "_hidden_params"):
                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                    "id", None
                )
            if isinstance(result, ModelResponse):
                result._response_ms = (
                    end_time - start_time
@@ -6527,6 +6539,13 @@ class CustomStreamWrapper:
        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        self.holding_chunk = ""
        self.complete_response = ""
        self._hidden_params = {
            "model_id": (
                self.logging_obj.model_call_details.get("litellm_params", {})
                .get("model_info", {})
                .get("id", None)
            )
        }  # returned as x-litellm-model-id response header in proxy

    def __iter__(self):
        return self
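Per the comment above, the proxy surfaces this value as the x-litellm-model-id response header. A hedged sketch of checking it from a client, assuming a litellm proxy listening on localhost:4000 (URL and port are placeholders):

import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hi"}],
    },
)
# Which deployment actually served the request, if the router set an id.
print(resp.headers.get("x-litellm-model-id"))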
@@ -7417,6 +7436,15 @@ class CustomStreamWrapper:
            threading.Thread(
                target=self.logging_obj.success_handler, args=(response,)
            ).start()  # log response
            # RETURN RESULT
            if hasattr(response, "_hidden_params"):
                response._hidden_params["model_id"] = (
                    self.logging_obj.model_call_details.get(
                        "litellm_params", {}
                    )
                    .get("model_info", {})
                    .get("id", None)
                )
            return response
        except StopIteration:
            raise  # Re-raise StopIteration
@@ -7467,6 +7495,16 @@ class CustomStreamWrapper:
                            processed_chunk,
                        )
                    )
                    # RETURN RESULT
                    if hasattr(processed_chunk, "_hidden_params"):
                        model_id = (
                            self.logging_obj.model_call_details.get(
                                "litellm_params", {}
                            )
                            .get("model_info", {})
                            .get("id", None)
                        )
                        processed_chunk._hidden_params["model_id"] = model_id
                    return processed_chunk
                raise StopAsyncIteration
            else:  # temporary patch for non-aiohttp async calls
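Since __anext__ now stamps the id onto every processed chunk, a streaming consumer can recover the serving deployment mid-stream. A minimal sketch, reusing the router from the first example (the hasattr guard again mirrors the diff):

import asyncio

async def stream_example():
    stream = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    async for chunk in stream:
        # Each chunk's _hidden_params carries the deployment id set above.
        if hasattr(chunk, "_hidden_params"):
            print(chunk._hidden_params.get("model_id"))

asyncio.run(stream_example())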