From e813e984f74ea09ea92646c44c5a5ab7a30bbff0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 25 Jun 2024 16:03:47 -0700
Subject: [PATCH] fix(predibase.py): support json schema on predibase

---
 litellm/llms/predibase.py               | 59 ++++++++++++++++++++++---
 litellm/proxy/_super_secret_config.yaml | 16 +++----
 litellm/utils.py                        | 10 ++++-
 3 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 7a137da70..534f8e26f 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -15,6 +15,8 @@ import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -145,7 +147,49 @@ class PredibaseConfig:
         }
 
     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
 
 
 class PredibaseChatCompletion(BaseLLM):
@@ -224,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                     status_code=response.status_code,
                 )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
+            if not isinstance(completion_response, dict):
                 raise PredibaseError(
                     status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key in the response dictionary - {completion_response}",
                 )
-
             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 94df97c54..2060f61ca 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -14,14 +14,10 @@ model_list:
 - model_name: fake-openai-endpoint
   litellm_params:
     model: predibase/llama-3-8b-instruct
-    # api_base: "http://0.0.0.0:8081"
+    api_base: "http://0.0.0.0:8081"
     api_key: os.environ/PREDIBASE_API_KEY
     tenant_id: os.environ/PREDIBASE_TENANT_ID
-    adapter_id: qwoiqjdoqin
-    max_retries: 0
-    temperature: 0.1
     max_new_tokens: 256
-    return_full_text: false
 
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
@@ -97,8 +93,8 @@ assistant_settings:
 router_settings:
   enable_pre_call_checks: true
 
-general_settings:
-  alerting: ["slack"]
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    team_id_jwt_field: "client_id"
\ No newline at end of file
+# general_settings:
+#   # alerting: ["slack"]
+#   enable_jwt_auth: True
+#   litellm_jwtauth:
+#     team_id_jwt_field: "client_id"
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 00833003b..4465c5b0a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2609,7 +2609,15 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
            optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
+    elif custom_llm_provider == "predibase":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.PredibaseConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
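
Usage sketch: with this change, a response_format value sent for a Predibase model is advertised by
PredibaseConfig.get_supported_openai_params and forwarded unchanged by map_openai_params into
optional_params. The call below is a rough illustration only; the model name and tenant_id kwarg mirror
the example proxy config above, and the exact schema payload Predibase accepts is an assumption, not
something this patch pins down.

    # Illustrative sketch only: exercises the new "response_format" pass-through
    # for a Predibase model. Model name, kwargs, and schema shape are assumptions.
    import os

    import litellm

    response = litellm.completion(
        model="predibase/llama-3-8b-instruct",
        messages=[{"role": "user", "content": "List three fruits as JSON."}],
        api_key=os.environ["PREDIBASE_API_KEY"],
        tenant_id=os.environ["PREDIBASE_TENANT_ID"],  # provider-specific kwarg, as in the proxy config
        # Forwarded via PredibaseConfig.map_openai_params -> optional_params["response_format"];
        # the schema shape below is illustrative, not defined by this patch.
        response_format={
            "type": "json_object",
            "schema": {
                "type": "object",
                "properties": {"fruits": {"type": "array", "items": {"type": "string"}}},
                "required": ["fruits"],
            },
        },
        max_tokens=256,
    )
    print(response.choices[0].message.content)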