fix(utils.py): improved predibase exception mapping

adds unit testing + better coverage for predibase errors
Krrish Dholakia 2024-06-08 14:32:43 -07:00
parent 192dfbcd63
commit 39ee6be477
11 changed files with 220 additions and 46 deletions
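
The commit body mentions new unit tests for Predibase error coverage. A minimal pytest sketch in that spirit might look like the following; the deployment name, tenant id, and the exact mapped exception class are assumptions for illustration, not details taken from this diff:

import pytest
import litellm

def test_predibase_bad_key_maps_to_auth_error():
    # With an invalid key, the provider error should surface as one of
    # litellm's OpenAI-style exceptions rather than a raw HTTP failure.
    with pytest.raises(litellm.AuthenticationError):
        litellm.completion(
            model="predibase/llama-3-8b",  # hypothetical deployment name
            messages=[{"role": "user", "content": "hi"}],
            tenant_id="fake-tenant",       # hypothetical tenant id
            api_key="bad-key",
        )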

litellm/llms/predibase.py

@@ -3,6 +3,7 @@
 from functools import partial
 import os, types
+import traceback
 import json
 from enum import Enum
 import requests, copy # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
            "details" in completion_response
            and "tokens" in completion_response["details"]
        ):
-            model_response.choices[0].finish_reason = completion_response[
-                "details"
-            ]["finish_reason"]
+            model_response.choices[0].finish_reason = map_finish_reason(
+                completion_response["details"]["finish_reason"]
+            )
            sum_logprob = 0
            for token in completion_response["details"]["tokens"]:
-                if token["logprob"] != None:
+                if token["logprob"] is not None:
                    sum_logprob += token["logprob"]
            model_response["choices"][0][
                "message"
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
            ):
                sum_logprob = 0
                for token in item["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                        sum_logprob += token["logprob"]
                if len(item["generated_text"]) > 0:
                    message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
                else:
                    message_obj = Message(content=None)
                choice_obj = Choices(
-                    finish_reason=item["finish_reason"],
+                    finish_reason=map_finish_reason(item["finish_reason"]),
                    index=idx + 1,
                    message=message_obj,
                )
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
        ## CALCULATING USAGE
        prompt_tokens = 0
        try:
-            prompt_tokens = len(
-                encoding.encode(model_response["choices"][0]["message"]["content"])
-            ) ##[TODO] use a model-specific tokenizer here
-        except:
+            prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
            # this should remain non blocking we should not block a response returning if calculating usage fails
            pass
        output_text = model_response["choices"][0]["message"].get("content", "")
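
The replacement fallback estimates prompt tokens from the request messages with litellm's public token_counter helper instead of re-encoding the response content. Usage is along these lines (which tokenizer applies when no model is passed is an assumption about litellm's defaults):

import litellm

messages = [{"role": "user", "content": "What is Predibase?"}]

# Same call the except-path above falls back to.
prompt_tokens = litellm.token_counter(messages=messages)
print(prompt_tokens)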
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
        logging_obj,
        optional_params: dict,
        tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
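
The new timeout parameter is typed Union[float, httpx.Timeout], so callers can pass either a single number of seconds or httpx's granular per-phase limits. A quick sketch of both accepted forms:

import httpx

# One overall limit, in seconds, for the whole request.
simple_timeout = 600.0

# Per-phase limits; this mirrors the hard-coded value the commit
# removes from async_completion further down.
granular_timeout = httpx.Timeout(timeout=600.0, connect=5.0)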
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"
+
        if "https" in model:
            completion_url = model
        elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):
            completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
-        if optional_params.get("stream", False) == True:
+        if optional_params.get("stream", False) is True:
            completion_url += "/generate_stream"
        else:
            completion_url += "/generate"
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
            },
        )
        ## COMPLETION CALL
-        if acompletion == True:
+        if acompletion is True:
            ### ASYNC STREAMING
-            if stream == True:
+            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
+                    timeout=timeout,
                ) # type: ignore
            else:
                ### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
+                    timeout=timeout,
                ) # type: ignore
        ### SYNC STREAMING
-        if stream == True:
+        if stream is True:
            response = requests.post(
                completion_url,
                headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
            headers=headers,
            data=json.dumps(data),
        )
-
        return self.process_response(
            model=model,
            response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
        stream,
        data: dict,
        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> ModelResponse:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
        try:
-            response = await self.async_handler.post(
+            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise PredibaseError(
-                status_code=e.response.status_code, message=e.response.text
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
-            raise PredibaseError(status_code=500, message=str(e))
+            raise PredibaseError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
        return self.process_response(
            model=model,
            response=response,
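
The reworked except blocks are the core of the improved exception mapping: HTTP failures keep their real status code and get an "HTTPStatusError - " prefix, while everything else becomes a 500 carrying the traceback. A self-contained sketch of the same pattern; PredibaseErrorSketch and wrap_predibase_error are illustrative stand-ins, not litellm internals:

import traceback
import httpx

class PredibaseErrorSketch(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

def wrap_predibase_error(exc: Exception) -> PredibaseErrorSketch:
    # Meant to be called from inside an except block, so that
    # traceback.format_exc() sees the active exception.
    if isinstance(exc, httpx.HTTPStatusError):
        # Preserve the status code and tag the message so an upstream
        # exception mapper can recognize HTTP failures.
        return PredibaseErrorSketch(
            status_code=exc.response.status_code,
            message="HTTPStatusError - {}".format(exc.response.text),
        )
    # Anything else is treated as an internal error, traceback attached.
    return PredibaseErrorSketch(
        status_code=500,
        message="{}\n{}".format(str(exc), traceback.format_exc()),
    )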
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
        api_key,
        logging_obj,
        data: dict,
+        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
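
End to end, the threaded-through timeout lets callers bound Predibase requests from the public entry point. A hypothetical call; the deployment name, tenant id, and key below are placeholders:

import litellm

response = litellm.completion(
    model="predibase/llama-3-8b",                     # placeholder deployment
    messages=[{"role": "user", "content": "hello"}],
    tenant_id="my-tenant-id",                         # placeholder tenant
    api_key="pb_...",                                 # placeholder key
    timeout=30.0,  # or httpx.Timeout(timeout=600.0, connect=5.0)
)
print(response)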