forked from phoenix/litellm-mirror
fix(completion()): add request_timeout as a param, fix claude error when request_timeout set
parent a724d4bed2
commit 8120477be4
5 changed files with 18 additions and 8 deletions
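
For context, here is a minimal usage sketch of the new parameter, mirroring the updated Claude test below. This is illustrative only, not part of the diff; it assumes litellm is installed and ANTHROPIC_API_KEY is exported.

    # Illustrative only: request_timeout is consumed by litellm itself and is
    # not forwarded to the provider (see the get_optional_params changes below).
    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # Give up on the request after roughly 10 seconds.
    response = completion(
        model="claude-instant-1",
        messages=messages,
        request_timeout=10,
    )
    print(response)

Providers that do not accept a timeout parameter should be unaffected, since the kwarg is handled inside litellm rather than passed through.
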
@@ -168,6 +168,7 @@ def completion(
     logit_bias: dict = {},
     user: str = "",
     deployment_id = None,
+    request_timeout: Optional[int] = None,
     # Optional liteLLM function params
     **kwargs,
 ) -> ModelResponse:

@@ -220,7 +221,7 @@ def completion(
     metadata = kwargs.get('metadata', None)
     fallbacks = kwargs.get('fallbacks', [])
     ######## end of unpacking kwargs ###########
-    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"]
+    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout"]
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "metadata", "fallbacks"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider

@@ -260,6 +261,7 @@ def completion(
         frequency_penalty=frequency_penalty,
         logit_bias=logit_bias,
         user=user,
+        request_timeout=request_timeout,
         deployment_id=deployment_id,
         # params to identify the model
         model=model,

@@ -40,7 +40,7 @@ def test_completion_claude():
     try:
         # test without max tokens
         response = completion(
-            model="claude-instant-1", messages=messages
+            model="claude-instant-1", messages=messages, request_timeout=10,
         )
         # Add any assertions here to check the response
         print(response)

@@ -48,6 +48,8 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+test_completion_claude()
+
 def test_completion_claude_max_tokens():
     try:
         litellm.set_verbose = True

@@ -339,13 +341,13 @@ def test_completion_cohere(): # commenting for now as the cohere endpoint is bei
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_cohere()
+# test_completion_cohere()
 
 
 def test_completion_openai():
     try:
         litellm.api_key = os.environ['OPENAI_API_KEY']
-        response = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=10)
+        response = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=10, request_timeout=10)
         print("This is the response object\n", response)
         print("\n\nThis is response ms:", response.response_ms)
 

@@ -362,7 +364,7 @@ def test_completion_openai():
         litellm.api_key = None
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_openai()
+test_completion_openai()
 
 
 def test_completion_openai_prompt():

@@ -1018,7 +1020,7 @@ def test_completion_with_fallbacks():
 def test_completion_ai21():
     model_name = "j2-light"
     try:
-        response = completion(model=model_name, messages=messages, max_tokens=100, temperature=0.8, logger_fn=logger_fn)
+        response = completion(model=model_name, messages=messages, max_tokens=100, temperature=0.8)
         # Add any assertions here to check the response
         print(response)
         print(response.response_ms)

@@ -64,6 +64,8 @@ def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout):
             local_timeout_duration = timeout_duration
             if "force_timeout" in kwargs:
                 local_timeout_duration = kwargs["force_timeout"]
+            elif "request_timeout" in kwargs and kwargs["request_timeout"] is not None:
+                local_timeout_duration = kwargs["request_timeout"]
             try:
                 value = await asyncio.wait_for(
                     func(*args, **kwargs), timeout=timeout_duration

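The two added lines let request_timeout drive the async timeout wrapper the same way force_timeout already does, with force_timeout taking precedence when both are set. A standalone sketch of that selection order (hypothetical helper, assuming plain asyncio; not litellm code):

    import asyncio

    # Hypothetical helper, not litellm code: pick the timeout the way the
    # decorator now does, then enforce it with asyncio.wait_for.
    async def call_with_timeout(func, default_timeout, **kwargs):
        timeout_s = default_timeout
        if "force_timeout" in kwargs:  # explicit override still wins
            timeout_s = kwargs["force_timeout"]
        elif "request_timeout" in kwargs and kwargs["request_timeout"] is not None:
            timeout_s = kwargs["request_timeout"]  # new: honor request_timeout too
        # wait_for cancels the awaited call and raises asyncio.TimeoutError on expiry
        return await asyncio.wait_for(func(**kwargs), timeout=timeout_s)
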
@@ -947,6 +947,7 @@ def get_optional_params( # use the openai defaults
     frequency_penalty=0,
     logit_bias={},
     user="",
+    request_timeout=None,
     deployment_id=None,
     model=None,
     custom_llm_provider="",

@@ -971,6 +972,7 @@ def get_optional_params( # use the openai defaults
         "logit_bias":{},
         "user":"",
         "deployment_id":None,
+        "request_timeout":None,
         "model":None,
         "custom_llm_provider":"",
     }

@@ -991,6 +993,8 @@ def get_optional_params( # use the openai defaults
             if k not in supported_params:
                 if k == "n" and n == 1: # langchain sends n=1 as a default value
                     pass
+                if k == "request_timeout": # litellm handles request time outs
+                    pass
                 else:
                     unsupported_params.append(k)
         if unsupported_params and not litellm.drop_params:

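With the carve-out above, request_timeout is treated as a litellm-level setting rather than a provider kwarg, so it is no longer flagged as unsupported or forwarded to providers that do not accept it; that appears to be the Claude fix referenced in the commit message. A minimal sketch of the intended filtering rule (hypothetical helper, not the actual _check_valid_arg):

    # Hypothetical sketch of the intended rule, not the litellm implementation.
    def find_unsupported(non_default_params, supported_params, n=None, drop_params=False):
        unsupported = []
        for k in non_default_params:
            if k in supported_params:
                continue
            if k == "n" and n == 1:      # langchain sends n=1 as a default value
                continue
            if k == "request_timeout":   # handled by litellm's timeout wrapper
                continue
            unsupported.append(k)
        if unsupported and not drop_params:
            raise ValueError(f"Unsupported params for this provider: {unsupported}")
        return unsupported
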
@@ -1273,7 +1277,7 @@ def get_optional_params( # use the openai defaults
         if stream:
             optional_params["stream"] = stream
     else: # assume passing in params for openai/azure openai
-        supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "deployment_id"]
+        supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "deployment_id", "request_timeout"]
         _check_valid_arg(supported_params=supported_params)
         optional_params = non_default_params
         # if user passed in non-default kwargs for specific providers/models, pass them along

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.821"
+version = "0.1.822"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"