Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
style(test_completion.py): fix merge conflict
This commit is contained in:
parent 396d9d8e38
commit dd7e397650
22 changed files with 1535 additions and 250 deletions
@@ -1,9 +1,10 @@
 ## Uses the huggingface text generation inference API
-import os, copy
+import os, copy, types
 import json
 from enum import Enum
 import requests
 import time
 import litellm
 from typing import Callable
 from litellm.utils import ModelResponse, Choices, Message
+from typing import Optional
@@ -17,11 +18,52 @@ class HuggingfaceError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs

-# contains any default values we need to pass to the provider
-HuggingfaceConfig = {
-    "return_full_text": False, # override by setting - completion(..,return_full_text=True)
-    "details": True # needed for getting logprobs etc. for tgi models. override by setting - completion(..., details=False)
-}
+class HuggingfaceConfig():
+    """
+    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
+    """
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: Optional[bool] = True  # enables returning logprobs + best of
+    max_new_tokens: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = False  # by default don't return the input as part of the output
+    seed: Optional[int] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_n_tokens: Optional[int] = None
+    top_p: Optional[int] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(self,
+                 best_of: Optional[int] = None,
+                 decoder_input_details: Optional[bool] = None,
+                 details: Optional[bool] = None,
+                 max_new_tokens: Optional[int] = None,
+                 repetition_penalty: Optional[float] = None,
+                 return_full_text: Optional[bool] = None,
+                 seed: Optional[int] = None,
+                 temperature: Optional[float] = None,
+                 top_k: Optional[int] = None,
+                 top_n_tokens: Optional[int] = None,
+                 top_p: Optional[int] = None,
+                 truncate: Optional[int] = None,
+                 typical_p: Optional[float] = None,
+                 watermark: Optional[bool] = None
+                 ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}

 def validate_environment(api_key):
     headers = {
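
The class replaces the module-level dict while keeping its two defaults (details=True, return_full_text=False); the new types import in the first hunk exists only for the isinstance filter inside get_config(). A minimal sketch of the resulting behavior, inferred from the code above (the printed dicts are illustrative):

import litellm

# Only class attributes with non-None values survive the get_config() filter,
# so out of the box just the two defaults come back.
print(litellm.HuggingfaceConfig.get_config())
# {'details': True, 'return_full_text': False}

# __init__ writes every non-None argument onto the class itself via setattr,
# so instantiating once changes the defaults for all later callers.
litellm.HuggingfaceConfig(max_new_tokens=200)
print(litellm.HuggingfaceConfig.get_config())
# {'details': True, 'return_full_text': False, 'max_new_tokens': 200}

Note the design choice: because setattr targets self.__class__, an instance is only a vehicle for mutating class-level defaults; two instances never hold independent configs.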
@@ -74,8 +116,10 @@ def get_hf_task_for_model(model):
         return "text-generation-inference"
     elif model in conversational_models:
         return "conversational"
-    else:
+    elif "roneneldan/TinyStories" in model:
+        return None
+    else:
         return "text-generation-inference" # default to tgi

 def completion(
     model: str,
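
A quick spot-check of the new routing. The second model name below is hypothetical; only the roneneldan/TinyStories substring check comes from the code, and the example assumes the made-up model appears in none of the earlier model lists:

# Substring match, so any TinyStories checkpoint takes the plain text-generation path (task None).
assert get_hf_task_for_model("roneneldan/TinyStories-1M") is None
# Unrecognized models still fall through to the TGI default.
assert get_hf_task_for_model("made-up-org/made-up-model") == "text-generation-inference"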
@@ -108,8 +152,9 @@ def completion(
         completion_url = f"https://api-inference.huggingface.co/models/{model}"

     ## Load Config
-    for k, v in HuggingfaceConfig.items():
-        if k not in optional_params:
+    config = litellm.HuggingfaceConfig.get_config()
+    for k, v in config.items():
+        if k not in optional_params: # completion(top_k=3) > huggingfaceConfig(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

     ### MAP INPUT PARAMS
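
The merge rule here: anything the caller passed explicitly wins over the config defaults. A self-contained sketch with made-up values (no litellm import needed):

config = {"details": True, "return_full_text": False, "top_k": 3}  # as if returned by get_config()
optional_params = {"top_k": 10}  # as if the caller ran completion(..., top_k=10)

for k, v in config.items():
    if k not in optional_params:  # only fill in keys the caller did not set
        optional_params[k] = v

assert optional_params == {"top_k": 10, "details": True, "return_full_text": False}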
@@ -149,19 +194,11 @@ def completion(
             )
         else:
             prompt = prompt_factory(model=model, messages=messages)
-        if "https://api-inference.huggingface.co/models" in completion_url:
-            inference_params = copy.deepcopy(optional_params)
-            data = {
-                "inputs": prompt,
-                "parameters": inference_params,
-                "stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
-            }
-        else:
-            data = {
-                "inputs": prompt,
-                "parameters": optional_params,
-                "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
-            }
+        data = {
+            "inputs": prompt,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
+        }
         input_text = prompt
     else:
         # Non TGI and Conversational llms
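
With the branch removed, the same payload is built whether completion_url points at the hosted Inference API or a self-hosted TGI server, and the copy.deepcopy of optional_params goes away with it. An illustrative payload built with hypothetical values:

optional_params = {"details": True, "return_full_text": False, "max_new_tokens": 50}
prompt = "Hello, world"

data = {
    "inputs": prompt,
    "parameters": optional_params,
    "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
}
# data["stream"] is False here because optional_params carries no "stream" key.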