forked from phoenix/litellm-mirror
Merge pull request #3571 from BerriAI/litellm_hf_classifier_support
Huggingface classifier support
commit 1aa567f3b5
6 changed files with 415 additions and 64 deletions
@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
 <Tabs>
 <TabItem value="tgi" label="Text-generation-interface (TGI)">

+By default, LiteLLM will assume a huggingface call follows the TGI format.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion

@@ -40,9 +45,58 @@ response = completion(
 print(response)
 ```

+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: wizard-coder
+    litellm_params:
+      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+  "model": "wizard-coder",
+  "messages": [
+    {
+      "role": "user",
+      "content": "I like you!"
+    }
+  ],
+}'
+```
+
+</TabItem>
+</Tabs>
 </TabItem>
 <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">

+Append `conversational` to the model name
+
+e.g. `huggingface/conversational/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion

@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

 # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
+    model="huggingface/conversational/facebook/blenderbot-400M-distill",
     messages=messages,
     api_base="https://my-endpoint.huggingface.cloud"
 )

@@ -62,7 +116,123 @@ response = completion(
 print(response)
 ```
 </TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: blenderbot
+    litellm_params:
+      model: huggingface/conversational/facebook/blenderbot-400M-distill
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+  "model": "blenderbot",
+  "messages": [
+    {
+      "role": "user",
+      "content": "I like you!"
+    }
+  ],
+}'
+```
+
+</TabItem>
+</Tabs>
+
+</TabItem>
+
+<TabItem value="classification" label="Text Classification">
+
+Append `text-classification` to the model name
+
+e.g. `huggingface/text-classification/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+# [OPTIONAL] set env var
+os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
+
+messages = [{ "content": "I like you, I love you!","role": "user"}]
+
+# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
+response = completion(
+    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+    messages=messages,
+    api_base="https://my-endpoint.endpoints.huggingface.cloud",
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: bert-classifier
+    litellm_params:
+      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+  "model": "bert-classifier",
+  "messages": [
+    {
+      "role": "user",
+      "content": "I like you!"
+    }
+  ],
+}'
+```
+
+</TabItem>
+</Tabs>
+
+</TabItem>
+
+<TabItem value="none" label="Text Generation (NOT TGI)">
+
+Append `text-generation` to the model name
+
+e.g. `huggingface/text-generation/<model-name>`
+
 ```python
 import os

@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

 # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
+    model="huggingface/text-generation/roneneldan/TinyStories-3M",
     messages=messages,
     api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
 )
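For text-classification models, the raw classifier output is returned as a JSON string in the assistant message content (see the `json.dumps` handling in the handler changes below). A minimal parsing sketch — the `[[{"label", "score"}]]` shape mirrors the mocked response used in the tests further down; the literal values here are illustrative only:

```python
import json

# Illustrative content string, shaped like the mocked classifier response in the tests below.
content = (
    '[[{"label": "LABEL_0", "score": 0.9990691542625427}, '
    '{"label": "LABEL_1", "score": 0.0009308889275416732}]]'
)
scores = json.loads(content)[0]
best = max(scores, key=lambda s: s["score"])
print(best["label"], best["score"])  # LABEL_0 0.9990691542625427
```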
@@ -6,10 +6,12 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any
+from typing import Callable, Dict, List, Any, Literal
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+import enum


 class HuggingfaceError(Exception):

@@ -39,11 +41,29 @@
         )  # Call the base class constructor with the parameters it needs


+hf_task_list = [
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+hf_tasks = Literal[
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+
 class HuggingfaceConfig:
     """
     Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
     """

+    hf_task: Optional[hf_tasks] = (
+        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
+    )
     best_of: Optional[int] = None
     decoder_input_details: Optional[bool] = None
     details: Optional[bool] = True  # enables returning logprobs + best of

@@ -101,6 +121,51 @@ class HuggingfaceConfig:
             and v is not None
         }

+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "echo",
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+        return optional_params
+
+
 def output_parser(generated_text: str):
     """
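The new `map_openai_params` above translates OpenAI-style arguments into Hugging Face text-generation parameters. A minimal sketch of the mapping, assuming `HuggingfaceConfig` is reachable from the top-level `litellm` package as it is used in the utils changes further down; result key order is not guaranteed:

```python
import litellm

# temperature=0 is bumped to 0.01 and max_tokens=0 to 1 because TGI rejects
# non-positive values; n maps to best_of (with do_sample=True) and echo to
# decoder_input_details.
mapped = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0.0, "max_tokens": 0, "n": 2, "echo": True},
    optional_params={},
)
print(mapped)
# roughly: {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2,
#           "do_sample": True, "decoder_input_details": True}
```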
@@ -162,16 +227,18 @@ def read_tgi_conv_models():
         return set(), set()


-def get_hf_task_for_model(model):
+def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    if model.split("/")[0] in hf_task_list:
+        return model.split("/")[0]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
         return "text-generation-inference"
     elif model in conversational_models:
         return "conversational"
     elif "roneneldan/TinyStories" in model:
-        return None
+        return "text-generation"
     else:
         return "text-generation-inference"  # default to tgi

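A quick sketch of how the routing above resolves a task; the import path follows the test further down in this diff, and the model names are examples:

```python
from litellm.llms.huggingface_restapi import get_hf_task_for_model

# An explicit task prefix wins over the lookup tables.
print(get_hf_task_for_model("text-classification/shahrukhx01/question-vs-statement-classifier"))
# -> "text-classification"

# TinyStories models now map to plain text-generation instead of None.
print(get_hf_task_for_model("roneneldan/TinyStories-3M"))
# -> "text-generation"

# Anything unrecognised still defaults to TGI.
print(get_hf_task_for_model("some-org/some-unknown-model"))
# -> "text-generation-inference"
```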
@@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
         self,
         completion_response,
         model_response,
-        task,
+        task: hf_tasks,
         optional_params,
         encoding,
         input_text,

@@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
                     )
                     choices_list.append(choice_obj)
                 model_response["choices"].extend(choices_list)
+        elif task == "text-classification":
+            model_response["choices"][0]["message"]["content"] = json.dumps(
+                completion_response
+            )
         else:
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = output_parser(

@@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
         try:
             headers = self.validate_environment(api_key, headers)
             task = get_hf_task_for_model(model)
+            ## VALIDATE API FORMAT
+            if task is None or not isinstance(task, str) or task not in hf_task_list:
+                raise Exception(
+                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
+                )
+
             print_verbose(f"{model}, {task}")
             completion_url = ""
             input_text = ""

@@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
                     inference_params.pop("return_full_text")
                 data = {
                     "inputs": prompt,
-                    "parameters": inference_params,
-                    "stream": (  # type: ignore
+                }
+                if task == "text-generation-inference":
+                    data["parameters"] = inference_params
+                    data["stream"] = (  # type: ignore
                         True
                         if "stream" in optional_params
                         and optional_params["stream"] == True
                         else False
-                    ),
-                }
+                    )
                 input_text = prompt
                 ## LOGGING
                 logging_obj.pre_call(

@@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
                 isinstance(completion_response, dict)
                 and "error" in completion_response
             ):
-                print_verbose(f"completion error: {completion_response['error']}")
+                print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                 print_verbose(f"response.status_code: {response.status_code}")
                 raise HuggingfaceError(
-                    message=completion_response["error"],
+                    message=completion_response["error"],  # type: ignore
                     status_code=response.status_code,
                 )
             return self.convert_to_model_response_object(

@@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
         data: dict,
         headers: dict,
         model_response: ModelResponse,
-        task: str,
+        task: hf_tasks,
         encoding: Any,
         input_text: str,
         model: str,
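To summarize the request-body change above: `parameters` and the `stream` flag are only attached when the resolved task is `text-generation-inference`; other tasks send the bare inputs. A hedged sketch of the resulting shapes (values are illustrative only, and other tasks may add their own fields elsewhere in the handler):

```python
# Illustrative payload shapes, not taken verbatim from the diff.
tgi_payload = {
    "inputs": "<s>[INST] I like you [/INST]",  # prompt built from the chat messages
    "parameters": {"max_new_tokens": 10, "details": True},  # mapped inference params
    "stream": False,
}

text_classification_payload = {
    "inputs": "I like you, I love you!",  # raw text, no TGI-specific fields
}
```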
@@ -13,6 +13,7 @@ import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+from unittest.mock import patch, MagicMock

 # litellm.num_retries=3
 litellm.cache = None

@@ -1137,7 +1138,7 @@ def test_get_hf_task_for_model():
     model = "roneneldan/TinyStories-3M"
     model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
-    assert model_type == None
+    assert model_type == "text-generation"


 # test_get_hf_task_for_model()

@@ -1145,15 +1146,92 @@
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-def hf_test_completion_tgi():
-    # litellm.set_verbose=True
+def tgi_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        {
+            "generated_text": "<|assistant|>\nI'm",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {
+                        "id": 28789,
+                        "text": "<",
+                        "logprob": -0.025222778,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.000003695488,
+                        "special": False,
+                    },
+                    {
+                        "id": 489,
+                        "text": "ass",
+                        "logprob": -0.0000019073486,
+                        "special": False,
+                    },
+                    {
+                        "id": 11143,
+                        "text": "istant",
+                        "logprob": -0.000002026558,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.0000015497208,
+                        "special": False,
+                    },
+                    {
+                        "id": 28767,
+                        "text": ">",
+                        "logprob": -0.0000011920929,
+                        "special": False,
+                    },
+                    {
+                        "id": 13,
+                        "text": "\n",
+                        "logprob": -0.00009703636,
+                        "special": False,
+                    },
+                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
+                    {
+                        "id": 28742,
+                        "text": "'",
+                        "logprob": -0.88183594,
+                        "special": False,
+                    },
+                    {
+                        "id": 28719,
+                        "text": "m",
+                        "logprob": -0.00032639503,
+                        "special": False,
+                    },
+                ],
+            },
+        }
+    ]
+    return mock_response
+
+
+def test_hf_test_completion_tgi():
+    litellm.set_verbose = True
     try:
-        response = completion(
-            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
-            messages=[{"content": "Hello, how are you?", "role": "user"}],
-        )
-        # Add any assertions here to check the response
-        print(response)
+        with patch("requests.post", side_effect=tgi_mock_post):
+            response = completion(
+                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                max_tokens=10,
+            )
+            # Add any assertions here to check the response
+            print(response)
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:

@@ -1191,6 +1269,40 @@ def hf_test_completion_tgi():
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_none_task()
+
+
+def mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        [
+            {"label": "LABEL_0", "score": 0.9990691542625427},
+            {"label": "LABEL_1", "score": 0.0009308889275416732},
+        ]
+    ]
+    return mock_response
+
+
+def test_hf_classifier_task():
+    try:
+        with patch("requests.post", side_effect=mock_post):
+            litellm.set_verbose = True
+            user_message = "I like you. I love you"
+            messages = [{"content": user_message, "role": "user"}]
+            response = completion(
+                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+                messages=messages,
+            )
+            print(f"response: {response}")
+            assert isinstance(response, litellm.ModelResponse)
+            assert isinstance(response.choices[0], litellm.Choices)
+            assert response.choices[0].message.content is not None
+            assert isinstance(response.choices[0].message.content, str)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {str(e)}")
+
+
 ########################### End of Hugging Face Tests ##############################################
 # def test_completion_hf_api():
 # # failing on circle ci commenting out
@@ -3,7 +3,27 @@ from litellm import get_optional_params

 litellm.add_function_to_prompt = True
 optional_params = get_optional_params(
-    tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
-    tool_choice= 'auto',
+    model="",
+    tools=[
+        {
+            "type": "function",
+            "function": {
+                "description": "Get the current weather in a given location",
+                "name": "get_current_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ],
+    tool_choice="auto",
 )
 assert optional_params is not None

@@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
 def test_azure_gpt_optional_params_gpt_vision():
     # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="azure",
         max_tokens=10,

@@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
 def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
     # if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="azure",
         max_tokens=10,

@@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():

 def test_openai_extra_headers():
     optional_params = litellm.utils.get_optional_params(
+        model="",
         user="John",
         custom_llm_provider="openai",
         max_tokens=10,
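These test updates track the signature change in the next file: `model` is now an explicit parameter of `get_optional_params` rather than a keyword defaulting to `None`, so callers pass it even when it is empty. A minimal hedged sketch of a call through the Hugging Face mapping path (output shown approximately):

```python
import litellm

# model can be an empty string when the provider branch does not need it,
# mirroring the updated tests above.
params = litellm.utils.get_optional_params(
    model="",
    custom_llm_provider="huggingface",
    temperature=0.0,
    max_tokens=10,
)
print(params)
# roughly: {"temperature": 0.01, "max_new_tokens": 10}
```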
@@ -4871,6 +4871,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,

@@ -4884,7 +4885,6 @@ def get_optional_params(
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,

@@ -4913,7 +4913,7 @@ def get_optional_params(

             passed_params[k] = v

-    optional_params = {}
+    optional_params: Dict = {}

     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:

@@ -5187,41 +5187,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-            passed_params.pop(
-                "echo", None
-            )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(

@@ -6181,7 +6149,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
             "stream",
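With the last hunk, the provider's supported-OpenAI-params list is owned by `HuggingfaceConfig`, so `echo` is now advertised alongside the previously hard-coded params. A small sketch, assuming the call shape from the signature shown above:

```python
import litellm

supported = litellm.utils.get_supported_openai_params(
    model="HuggingFaceH4/zephyr-7b-beta", custom_llm_provider="huggingface"
)
print(supported)
# ["stream", "temperature", "max_tokens", "top_p", "stop", "n", "echo"]
```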