forked from phoenix/litellm-mirror
fix(predibase.py): support json schema on predibase
This commit is contained in:
parent 6889a4c0dd
commit e813e984f7

3 changed files with 67 additions and 18 deletions
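For context, the point of this change is to let callers pass an OpenAI-style response_format (e.g. a JSON schema / json_object request) through to Predibase. Below is a minimal usage sketch, not part of the diff: the model alias, the tenant_id keyword, and the exact response_format payload Predibase accepts are assumptions for illustration.

# Hedged sketch: requesting JSON output from a Predibase model via litellm.
# Assumes PREDIBASE_API_KEY / PREDIBASE_TENANT_ID are set; the schema shape is illustrative.
import os
import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Give me a user profile as JSON."}],
    api_key=os.environ["PREDIBASE_API_KEY"],
    tenant_id=os.environ["PREDIBASE_TENANT_ID"],  # assumed pass-through kwarg for predibase
    response_format={"type": "json_object"},      # assumed payload; depends on endpoint support
    max_tokens=256,
)
print(response.choices[0].message.content)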
@@ -15,6 +15,8 @@ import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -145,7 +147,49 @@ class PredibaseConfig:
         }
 
     def get_supported_openai_params(self):
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
 
 
 class PredibaseChatCompletion(BaseLLM):
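To make the new map_openai_params concrete, here is a hedged sketch (not part of the diff) of what the mapping above produces for a sample set of OpenAI params; the sample values are illustrative.

# Hedged sketch of the translation performed by PredibaseConfig.map_openai_params above.
import litellm

config = litellm.PredibaseConfig()
mapped = config.map_openai_params(
    non_default_params={
        "temperature": 0,                             # bumped to 0.01; TGI rejects temperature == 0
        "n": 2,                                       # becomes best_of=2 plus do_sample=True
        "max_tokens": 0,                              # bumped to 1 and renamed to max_new_tokens
        "response_format": {"type": "json_object"},   # passed through unchanged
    },
    optional_params={},
)
# Expected result, per the mapping above:
# {"temperature": 0.01, "best_of": 2, "do_sample": True,
#  "max_new_tokens": 1, "response_format": {"type": "json_object"}}
print(mapped)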
@@ -224,15 +268,16 @@ class PredibaseChatCompletion(BaseLLM):
                 status_code=response.status_code,
             )
         else:
-            if (
-                not isinstance(completion_response, dict)
-                or "generated_text" not in completion_response
-            ):
+            if not isinstance(completion_response, dict):
                 raise PredibaseError(
                     status_code=422,
-                    message=f"response is not in expected format - {completion_response}",
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key response dictionary - {completion_response}",
                 )
 
             if len(completion_response["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = self.output_parser(
                     completion_response["generated_text"]
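The reshaped validation above splits the old combined check into two branches, so the 422 error now says whether the response was not a dict or merely missing "generated_text". A hedged standalone sketch of the same logic (hypothetical helper, not part of the diff; ValueError stands in for PredibaseError to keep it self-contained):

# Hedged sketch of the two-step validation introduced above.
def validate_completion_response(completion_response):
    if not isinstance(completion_response, dict):
        raise ValueError(
            f"'completion_response' is not a dictionary - {completion_response}"
        )
    elif "generated_text" not in completion_response:
        raise ValueError(
            f"'generated_text' is not a key in the response dictionary - {completion_response}"
        )
    return completion_response

validate_completion_response({"generated_text": "{\"name\": \"litellm\"}"})  # passes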
@@ -14,14 +14,10 @@ model_list:
   - model_name: fake-openai-endpoint
     litellm_params:
       model: predibase/llama-3-8b-instruct
-      # api_base: "http://0.0.0.0:8081"
+      api_base: "http://0.0.0.0:8081"
       api_key: os.environ/PREDIBASE_API_KEY
       tenant_id: os.environ/PREDIBASE_TENANT_ID
-      adapter_id: qwoiqjdoqin
-      max_retries: 0
-      temperature: 0.1
       max_new_tokens: 256
-      return_full_text: false
 
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
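Once a proxy entry like the one above is running, any OpenAI-compatible client can target the fake-openai-endpoint alias. A hedged sketch, not part of the diff: the proxy address/port and the placeholder key are assumptions that depend on how the proxy is started and secured.

# Hedged sketch: calling the LiteLLM proxy entry defined above through the OpenAI SDK.
from openai import OpenAI

client = OpenAI(
    base_url="http://0.0.0.0:4000",  # assumed proxy address/port
    api_key="sk-anything",           # placeholder; depends on proxy auth settings
)
resp = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "Return a JSON object with a 'city' field."}],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)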
@@ -97,8 +93,8 @@ assistant_settings:
 router_settings:
   enable_pre_call_checks: true
 
-general_settings:
-  alerting: ["slack"]
-  enable_jwt_auth: True
-  litellm_jwtauth:
-    team_id_jwt_field: "client_id"
+# general_settings:
+#   alerting: ["slack"]
+#   enable_jwt_auth: True
+#   litellm_jwtauth:
+#     team_id_jwt_field: "client_id"
@@ -2609,7 +2609,15 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
+    elif custom_llm_provider == "predibase":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.PredibaseConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    elif custom_llm_provider == "huggingface":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
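With predibase split into its own branch above, its param validation and mapping now run through PredibaseConfig instead of the HuggingFace path. A hedged check (not part of the diff) of what the provider should advertise after this change:

# Hedged sketch: "response_format" should now appear among the supported params for predibase.
import litellm

params = litellm.get_supported_openai_params(
    model="llama-3-8b-instruct", custom_llm_provider="predibase"
)
print(params)  # expected to include "response_format" alongside stream/temperature/etc.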