mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix(utils.py): improved predibase exception mapping
adds unit testing + better coverage for predibase errors
parent 192dfbcd63
commit 39ee6be477
11 changed files with 220 additions and 46 deletions
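The unit tests themselves are not part of this excerpt. As a rough illustration of the coverage the commit message describes, a test could assert that a bad Predibase credential surfaces as litellm's OpenAI-style exception rather than an unmapped provider error (model name, tenant, and key below are placeholders, not the commit's actual test code):

import pytest
import litellm


def test_predibase_bad_key_maps_to_authentication_error():
    # Hypothetical sketch: a 401 from Predibase should be re-raised
    # as litellm.AuthenticationError after exception mapping.
    with pytest.raises(litellm.AuthenticationError):
        litellm.completion(
            model="predibase/llama-3-8b-instruct",  # placeholder model id
            messages=[{"role": "user", "content": "hello"}],
            tenant_id="my-tenant",  # placeholder tenant
            api_key="bad-key",  # deliberately invalid
        )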
litellm/llms/predibase.py
@@ -3,6 +3,7 @@
 
 from functools import partial
 import os, types
+import traceback
 import json
 from enum import Enum
 import requests, copy  # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
                 "details" in completion_response
                 and "tokens" in completion_response["details"]
             ):
-                model_response.choices[0].finish_reason = completion_response[
-                    "details"
-                ]["finish_reason"]
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
                 sum_logprob = 0
                 for token in completion_response["details"]["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                         sum_logprob += token["logprob"]
                 model_response["choices"][0][
                     "message"
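Both changes in this hunk push the raw Predibase finish_reason through map_finish_reason before it is stored on the response. The intent is to normalize TGI/Predibase-style values onto the OpenAI-style values callers expect; an illustrative sketch of that idea (not litellm's actual map_finish_reason) looks like:

def normalize_finish_reason(raw: str) -> str:
    # Illustrative only: collapse provider-specific reasons onto OpenAI-style ones.
    mapping = {
        "eos_token": "stop",       # model emitted its end-of-sequence token
        "stop_sequence": "stop",   # a caller-supplied stop string was hit
        "length": "length",        # max_new_tokens budget exhausted
    }
    return mapping.get(raw, raw)   # unknown values pass through unchanged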
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
                 ):
                     sum_logprob = 0
                     for token in item["tokens"]:
-                        if token["logprob"] != None:
+                        if token["logprob"] is not None:
                             sum_logprob += token["logprob"]
                     if len(item["generated_text"]) > 0:
                         message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
                     else:
                         message_obj = Message(content=None)
                     choice_obj = Choices(
-                        finish_reason=item["finish_reason"],
+                        finish_reason=map_finish_reason(item["finish_reason"]),
                         index=idx + 1,
                         message=message_obj,
                     )
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
         ## CALCULATING USAGE
         prompt_tokens = 0
         try:
-            prompt_tokens = len(
-                encoding.encode(model_response["choices"][0]["message"]["content"])
-            )  ##[TODO] use a model-specific tokenizer here
-        except:
             prompt_tokens = litellm.token_counter(messages=messages)
-
+        except Exception:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
         output_text = model_response["choices"][0]["message"].get("content", "")
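Token counting for usage is now best-effort: the prompt count comes from litellm.token_counter and any failure is swallowed instead of failing the whole response. In isolation the pattern is simply (the message list is a made-up example):

import litellm

messages = [{"role": "user", "content": "what is the weather in SF?"}]
prompt_tokens = 0
try:
    prompt_tokens = litellm.token_counter(messages=messages)
except Exception:
    # non-blocking: a token-counting failure should never block the response
    pass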
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj,
         optional_params: dict,
         tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
+
         if "https" in model:
             completion_url = model
         elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):
 
         completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
 
-        if optional_params.get("stream", False) == True:
+        if optional_params.get("stream", False) is True:
             completion_url += "/generate_stream"
         else:
             completion_url += "/generate"
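For reference, with placeholder values tenant_id="my-tenant" and model="llama-3-8b-instruct", the URL built above resolves to https://serving.app.predibase.com/my-tenant/deployments/v2/llms/llama-3-8b-instruct/generate for regular calls, and the same path ending in /generate_stream when streaming is requested.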
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion == True:
+        if acompletion is True:
             ### ASYNC STREAMING
-            if stream == True:
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore
             else:
                 ### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore
 
         ### SYNC STREAMING
-        if stream == True:
+        if stream is True:
             response = requests.post(
                 completion_url,
                 headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
                 headers=headers,
                 data=json.dumps(data),
             )
-
         return self.process_response(
             model=model,
             response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
         stream,
         data: dict,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         litellm_params=None,
         logger_fn=None,
         headers={},
     ) -> ModelResponse:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
         try:
-            response = await self.async_handler.post(
+            response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
-                status_code=e.response.status_code, message=e.response.text
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
             )
         except Exception as e:
-            raise PredibaseError(status_code=500, message=str(e))
+            raise PredibaseError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
         return self.process_response(
             model=model,
             response=response,
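The async path now raises PredibaseError with the real HTTP status code, and unexpected failures carry the full traceback in the message. The "exception mapping" named in the commit title lives in litellm/utils.py, which is not shown in this excerpt; the rough idea, as an illustrative sketch rather than the actual utils.py code, is to turn that status code into the OpenAI-compatible exception family:

class ProviderAPIError(Exception):
    # Stand-in for the OpenAI-style exception hierarchy litellm re-raises;
    # names and structure here are illustrative, not litellm's real classes.
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)


class ProviderBadRequestError(ProviderAPIError): ...
class ProviderAuthenticationError(ProviderAPIError): ...
class ProviderRateLimitError(ProviderAPIError): ...


def map_predibase_exception(status_code: int, message: str) -> ProviderAPIError:
    # 4xx codes map to specific exception types callers can branch on;
    # everything else falls back to a generic API error carrying the original text.
    if status_code in (400, 422):
        return ProviderBadRequestError(status_code, message)
    if status_code in (401, 403):
        return ProviderAuthenticationError(status_code, message)
    if status_code == 429:
        return ProviderRateLimitError(status_code, message)
    return ProviderAPIError(status_code, message)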
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         data: dict,
+        timeout: Union[float, httpx.Timeout],
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
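These last hunks thread the caller's timeout down into async_completion and async_streaming instead of the previous hard-coded 600-second httpx.Timeout. From the caller's side, that means a per-request timeout set on litellm.completion is honored for Predibase; a hedged usage sketch (model and tenant are placeholders, PREDIBASE_API_KEY comes from the environment):

import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",  # placeholder model id
    messages=[{"role": "user", "content": "hello"}],
    tenant_id="my-tenant",                  # placeholder tenant
    timeout=30,                             # seconds, forwarded to the Predibase HTTP call
)
print(response.choices[0].message.content)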