forked from phoenix/litellm-mirror
Merge pull request #3571 from BerriAI/litellm_hf_classifier_support
Huggingface classifier support
This commit is contained in: commit 1aa567f3b5
6 changed files with 415 additions and 64 deletions
@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion

<Tabs>
<TabItem value="tgi" label="Text-generation-inference (TGI)">

By default, LiteLLM will assume a huggingface call follows the TGI format.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion
@ -40,9 +45,58 @@ response = completion(

print(response)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: wizard-coder
    litellm_params:
      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "wizard-coder",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```
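Since the proxy exposes an OpenAI-compatible endpoint, the same request can also be sent with the `openai` Python client. A minimal sketch, assuming the proxy from step 2 is running locally on port 4000 with the `sk-1234` key used above:

```python
import openai

# point the OpenAI client at the LiteLLM proxy started in step 2
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="wizard-coder",
    messages=[{"role": "user", "content": "I like you!"}],
)
print(response)
```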
</TabItem>
</Tabs>
</TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">

Add `conversational` after the `huggingface/` prefix in the model name,

e.g. `huggingface/conversational/<model-name>`

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
    model="huggingface/facebook/blenderbot-400M-distill",
    model="huggingface/conversational/facebook/blenderbot-400M-distill",
    messages=messages,
    api_base="https://my-endpoint.huggingface.cloud"
)
@ -62,7 +116,123 @@ response = completion(

print(response)
```
</TabItem>
<TabItem value="none" label="Non TGI/Conversational-task LLMs">
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: blenderbot
    litellm_params:
      model: huggingface/conversational/facebook/blenderbot-400M-distill
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "blenderbot",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```

</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">

Add `text-classification` after the `huggingface/` prefix in the model name,

e.g. `huggingface/text-classification/<model-name>`

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion

# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"

messages = [{ "content": "I like you, I love you!","role": "user"}]

# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)

print(response)
```
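For classification models, the handler added in this PR serializes the classifier's raw output into the assistant message as a JSON string, so it can be loaded back into Python. A small sketch, assuming a two-label classifier response:

```python
import json

# message content is a JSON string, e.g. '[[{"label": "LABEL_0", "score": 0.99}, {"label": "LABEL_1", "score": 0.01}]]'
labels = json.loads(response.choices[0].message.content)
print(labels)
```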
</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: bert-classifier
    litellm_params:
      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "bert-classifier",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```

</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">

Add `text-generation` after the `huggingface/` prefix in the model name,

e.g. `huggingface/text-generation/<model-name>`

```python
import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
    model="huggingface/roneneldan/TinyStories-3M",
    model="huggingface/text-generation/roneneldan/TinyStories-3M",
    messages=messages,
    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)
@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum


class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
        )  # Call the base class constructor with the parameters it needs


hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]
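The runtime list and the `Literal` type above enumerate the same four tasks; a quick consistency check one could run (an editor's sketch, not part of the diff):

```python
from typing import get_args

assert set(hf_task_list) == set(get_args(hf_tasks))
```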
class HuggingfaceConfig:
    """
    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
    """

    hf_task: Optional[hf_tasks] = (
        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
    )
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: Optional[bool] = True  # enables returning logprobs + best of

@ -101,6 +121,51 @@ class HuggingfaceConfig:
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "echo",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
        return optional_params
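A quick sketch of what the mapping above produces for a few OpenAI-style params (values follow the rules in `map_openai_params`; not output captured from a live endpoint):

```python
import litellm

params = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2},
    optional_params={},
)
# -> {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2, "do_sample": True}
print(params)
```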
def output_parser(generated_text: str):
    """

@ -162,16 +227,18 @@ def read_tgi_conv_models():
        return set(), set()


def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
    # read text file, cast it to set
    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
    if model.split("/")[0] in hf_task_list:
        return model.split("/")[0]  # type: ignore
    tgi_models, conversational_models = read_tgi_conv_models()
    if model in tgi_models:
        return "text-generation-inference"
    elif model in conversational_models:
        return "conversational"
    elif "roneneldan/TinyStories" in model:
        return None
        return "text-generation"
    else:
        return "text-generation-inference"  # default to tgi
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
        self,
        completion_response,
        model_response,
        task,
        task: hf_tasks,
        optional_params,
        encoding,
        input_text,

@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
                    )
                    choices_list.append(choice_obj)
                model_response["choices"].extend(choices_list)
        elif task == "text-classification":
            model_response["choices"][0]["message"]["content"] = json.dumps(
                completion_response
            )
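            # note: for text-classification the raw classifier output (a list of
            # {"label": ..., "score": ...} dicts) is passed through to the message content as a JSON string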
        else:
            if len(completion_response[0]["generated_text"]) > 0:
                model_response["choices"][0]["message"]["content"] = output_parser(

@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
        try:
            headers = self.validate_environment(api_key, headers)
            task = get_hf_task_for_model(model)
            ## VALIDATE API FORMAT
            if task is None or not isinstance(task, str) or task not in hf_task_list:
                raise Exception(
                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
                )

            print_verbose(f"{model}, {task}")
            completion_url = ""
            input_text = ""

@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
                inference_params.pop("return_full_text")
                data = {
                    "inputs": prompt,
                    "parameters": inference_params,
                    "stream": (  # type: ignore
                }
                if task == "text-generation-inference":
                    data["parameters"] = inference_params
                    data["stream"] = (  # type: ignore
                        True
                        if "stream" in optional_params
                        and optional_params["stream"] == True
                        else False
                        ),
                }
                    )
                input_text = prompt
                ## LOGGING
                logging_obj.pre_call(

@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
                    isinstance(completion_response, dict)
                    and "error" in completion_response
                ):
                    print_verbose(f"completion error: {completion_response['error']}")
                    print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                    print_verbose(f"response.status_code: {response.status_code}")
                    raise HuggingfaceError(
                        message=completion_response["error"],
                        message=completion_response["error"],  # type: ignore
                        status_code=response.status_code,
                    )
                return self.convert_to_model_response_object(

@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        task: str,
        task: hf_tasks,
        encoding: Any,
        input_text: str,
        model: str,
@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock

# litellm.num_retries=3
litellm.cache = None

@ -1137,7 +1138,7 @@ def test_get_hf_task_for_model():
    model = "roneneldan/TinyStories-3M"
    model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == None
    assert model_type == "text-generation"


# test_get_hf_task_for_model()
@ -1145,12 +1146,89 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def hf_test_completion_tgi():
    # litellm.set_verbose=True
def tgi_mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {
                        "id": 28789,
                        "text": "<",
                        "logprob": -0.025222778,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.000003695488,
                        "special": False,
                    },
                    {
                        "id": 489,
                        "text": "ass",
                        "logprob": -0.0000019073486,
                        "special": False,
                    },
                    {
                        "id": 11143,
                        "text": "istant",
                        "logprob": -0.000002026558,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.0000015497208,
                        "special": False,
                    },
                    {
                        "id": 28767,
                        "text": ">",
                        "logprob": -0.0000011920929,
                        "special": False,
                    },
                    {
                        "id": 13,
                        "text": "\n",
                        "logprob": -0.00009703636,
                        "special": False,
                    },
                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
                    {
                        "id": 28742,
                        "text": "'",
                        "logprob": -0.88183594,
                        "special": False,
                    },
                    {
                        "id": 28719,
                        "text": "m",
                        "logprob": -0.00032639503,
                        "special": False,
                    },
                ],
            },
        }
    ]
    return mock_response


def test_hf_test_completion_tgi():
    litellm.set_verbose = True
    try:
        with patch("requests.post", side_effect=tgi_mock_post):
            response = completion(
                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                messages=[{"content": "Hello, how are you?", "role": "user"}],
                max_tokens=10,
            )
            # Add any assertions here to check the response
            print(response)
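Where the placeholder comment above sits, assertions in the same style as the classifier test below could be added; a sketch based only on the mocked payload (lines intended for the body of the `with` block):

```python
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0].message.content, str)
            assert response.usage is not None
```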
@ -1191,6 +1269,40 @@ def hf_test_completion_tgi():
# except Exception as e:
#     pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()


def mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        [
            {"label": "LABEL_0", "score": 0.9990691542625427},
            {"label": "LABEL_1", "score": 0.0009308889275416732},
        ]
    ]
    return mock_response


def test_hf_classifier_task():
    try:
        with patch("requests.post", side_effect=mock_post):
            litellm.set_verbose = True
            user_message = "I like you. I love you"
            messages = [{"content": user_message, "role": "user"}]
            response = completion(
                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                messages=messages,
            )
            print(f"response: {response}")
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0], litellm.Choices)
            assert response.choices[0].message.content is not None
            assert isinstance(response.choices[0].message.content, str)
    except Exception as e:
        pytest.fail(f"Error occurred: {str(e)}")


########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
# # failing on circle ci commenting out
@ -3,7 +3,27 @@ from litellm import get_optional_params

litellm.add_function_to_prompt = True
optional_params = get_optional_params(
    tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
    tool_choice= 'auto',
    model="",
    tools=[
        {
            "type": "function",
            "function": {
                "description": "Get the current weather in a given location",
                "name": "get_current_weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)
assert optional_params is not None
@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
def test_azure_gpt_optional_params_gpt_vision():
    # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="azure",
        max_tokens=10,

@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
    # if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="azure",
        max_tokens=10,

@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():

def test_openai_extra_headers():
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="openai",
        max_tokens=10,
@ -4871,6 +4871,7 @@ def get_optional_params_embeddings(
def get_optional_params(
    # use the openai defaults
    # https://platform.openai.com/docs/api-reference/chat/create
    model: str,
    functions=None,
    function_call=None,
    temperature=None,

@ -4884,7 +4885,6 @@ def get_optional_params(
    frequency_penalty=None,
    logit_bias=None,
    user=None,
    model=None,
    custom_llm_provider="",
    response_format=None,
    seed=None,

@ -4913,7 +4913,7 @@ def get_optional_params(
            passed_params[k] = v

    optional_params = {}
    optional_params: Dict = {}

    common_auth_dict = litellm.common_cloud_provider_auth_params
    if custom_llm_provider in common_auth_dict["providers"]:

@ -5187,41 +5187,9 @@ def get_optional_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
        if temperature is not None:
            if temperature == 0.0 or temperature == 0:
                # hugging face exception raised when temp==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                temperature = 0.01
            optional_params["temperature"] = temperature
        if top_p is not None:
            optional_params["top_p"] = top_p
        if n is not None:
            optional_params["best_of"] = n
            optional_params["do_sample"] = (
                True  # Need to sample if you want best of for hf inference endpoints
        optional_params = litellm.HuggingfaceConfig().map_openai_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
        if stream is not None:
            optional_params["stream"] = stream
        if stop is not None:
            optional_params["stop"] = stop
        if max_tokens is not None:
            # HF TGI raises the following exception when max_new_tokens==0
            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
            if max_tokens == 0:
                max_tokens = 1
            optional_params["max_new_tokens"] = max_tokens
        if n is not None:
            optional_params["best_of"] = n
        if presence_penalty is not None:
            optional_params["repetition_penalty"] = presence_penalty
        if "echo" in passed_params:
            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
            optional_params["decoder_input_details"] = special_params["echo"]
            passed_params.pop(
                "echo", None
            )  # since we handle translating echo, we should not send it to TGI request
    elif custom_llm_provider == "together_ai":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(

@ -6181,7 +6149,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "seed",
        ]
    elif custom_llm_provider == "huggingface":
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
        return litellm.HuggingfaceConfig().get_supported_openai_params()
    elif custom_llm_provider == "together_ai":
        return [
            "stream",