fix(predibase.py): support json schema on predibase

Author: Krrish Dholakia
Date:   2024-06-25 16:03:47 -07:00
commit 91bbef4bcd (parent 1e51b8894f)
3 changed files with 67 additions and 18 deletions
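
What the change enables, sketched below (not part of the commit): with "response_format" now listed as a supported OpenAI param for Predibase, a litellm.completion call can forward a JSON schema constraint to the Predibase endpoint instead of dropping it. The model id, environment variables, and the exact response_format payload Predibase accepts are illustrative assumptions.

    # Illustrative sketch; model id, env vars, and schema payload are assumptions.
    import os

    import litellm

    os.environ["PREDIBASE_API_KEY"] = "pb_..."       # placeholder credential
    os.environ["PREDIBASE_TENANT_ID"] = "my-tenant"  # placeholder tenant id

    response = litellm.completion(
        model="predibase/llama-3-8b-instruct",  # hypothetical deployed model
        messages=[{"role": "user", "content": "Return a JSON object with a 'name' field."}],
        response_format={
            "type": "json_object",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        },
        max_tokens=128,
    )
    print(response.choices[0].message.content)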


@@ -15,6 +15,8 @@ import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -145,7 +147,49 @@ class PredibaseConfig:
         }

     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params


 class PredibaseChatCompletion(BaseLLM):
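
A quick sketch of what the new map_openai_params does with a mixed set of params (the import path assumes this file lives at litellm/llms/predibase.py; the call below is an illustration, not a test from the commit):

    from litellm.llms.predibase import PredibaseConfig  # assumed module path

    optional_params = PredibaseConfig().map_openai_params(
        non_default_params={
            "temperature": 0,  # clamped to 0.01, since TGI rejects temperature == 0
            "max_tokens": 0,   # clamped to 1 and renamed to max_new_tokens
            "n": 2,            # mapped to best_of, which also forces do_sample=True
            "response_format": {"type": "json_object"},  # now passed through unchanged
        },
        optional_params={},
    )
    # optional_params == {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2,
    #                     "do_sample": True, "response_format": {"type": "json_object"}}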
@@ -224,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                 status_code=response.status_code,
             )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
+            if not isinstance(completion_response, dict):
                 raise PredibaseError(
                     status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
                 )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key response dictionary - {completion_response}",
+                )

             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
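
For context, the stricter checks above split one generic 422 into two more specific errors. A rough sketch of the payload shapes each branch guards against (the example payloads are assumptions, not taken from Predibase's API):

    # Assumed example payloads, illustrating the two 422 branches above.
    not_a_dict = "<html>502 Bad Gateway</html>"       # -> "'completion_response' is not a dictionary"
    missing_key = {"error": "model not ready"}        # -> "'generated_text' is not a key ..."
    expected = {"generated_text": '{"name": "ok"}'}   # passes both checks; parsed by output_parser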