LiteLLM Minor Fixes and Improvements (09/07/2024) (#5580)
* fix(litellm_logging.py): set completion_start_time_float to end_time_float if none
Fixes https://github.com/BerriAI/litellm/issues/5500
* feat(__init__.py): add new 'openai_text_completion_compatible_providers' list
Fixes https://github.com/BerriAI/litellm/issues/5558
Correctly routes Fireworks AI calls made via the text completions endpoint
* fix: fix linting errors
* fix: fix linting errors
* fix(openai.py): fix exception raised
* fix(openai.py): fix error handling
* fix(_redis.py): allow all supported arguments for redis cluster (#5554)
* Revert "fix(_redis.py): allow all supported arguments for redis cluster (#5554)" (#5583)
This reverts commit f2191ef4cb.
* fix(router.py): return model alias w/ underlying deployment on router.get_model_list()
Fixes https://github.com/BerriAI/litellm/issues/5524#issuecomment-2336410666
* test: handle flaky tests
---------
Co-authored-by: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com>
parent c86b333054 · commit 4ac66bd843
14 changed files with 101 additions and 34 deletions
@@ -483,7 +483,12 @@ openai_compatible_providers: List = [
     "azure_ai",
     "github",
 ]
+
+openai_text_completion_compatible_providers: List = (
+    [  # providers that support `/v1/completions`
+        "together_ai",
+        "fireworks_ai",
+    ]
+)

 # well supported replicate llms
 replicate_models: List = [
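This new list is what `completion()` consults further down when deciding whether to send a call to a provider's `/v1/completions` endpoint. A minimal sketch of the call this enables, assuming a Fireworks AI key is configured in the environment (the model name mirrors the one used in the new test at the end of this diff):

import litellm

# With "fireworks_ai" in openai_text_completion_compatible_providers,
# a text_completion() call is routed to the provider's /v1/completions endpoint.
# Assumes FIREWORKS_AI_API_KEY is set in the environment.
response = litellm.text_completion(
    model="fireworks_ai/llama-v3p1-8b-instruct",
    prompt="Say hello",
)
print(response.choices[0].text)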
@@ -2329,6 +2329,8 @@ def get_standard_logging_object_payload(
         completion_start_time_float = completion_start_time.timestamp()
     elif isinstance(completion_start_time, float):
         completion_start_time_float = completion_start_time
+    else:
+        completion_start_time_float = end_time_float
     # clean up litellm hidden params
     clean_hidden_params = StandardLoggingHiddenParams(
         model_id=None,
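For reference, a standalone sketch of the normalization this hunk implements; `_to_timestamp_float` is a hypothetical helper name used here only for illustration, not a function in the codebase:

from datetime import datetime
from typing import Union

def _to_timestamp_float(
    completion_start_time: Union[datetime, float, None], end_time_float: float
) -> float:
    # datetime -> POSIX timestamp, float -> unchanged,
    # anything else (e.g. None) -> fall back to the end time (the fix for #5500).
    if isinstance(completion_start_time, datetime):
        return completion_start_time.timestamp()
    elif isinstance(completion_start_time, float):
        return completion_start_time
    else:
        return end_time_float

print(_to_timestamp_float(None, end_time_float=1725700000.0))  # -> 1725700000.0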
@@ -1263,6 +1263,7 @@ class OpenAIChatCompletion(BaseLLM):

             error_headers = getattr(e, "headers", None)
             if response is not None and hasattr(response, "text"):
+                error_headers = getattr(e, "headers", None)
                 raise OpenAIError(
                     status_code=500,
                     message=f"{str(e)}\n\nOriginal Response: {response.text}",
@@ -1800,12 +1801,11 @@ class OpenAITextCompletion(BaseLLM):
         headers: Optional[dict] = None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if headers is None:
                 headers = self.validate_environment(api_key=api_key)
             if model is None or messages is None:
-                raise OpenAIError(status_code=422, message=f"Missing model or messages")
+                raise OpenAIError(status_code=422, message="Missing model or messages")

             if (
                 len(messages) > 0
@@ -162,11 +162,10 @@ class AzureTextCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
-                    status_code=422, message=f"Missing model or messages"
+                    status_code=422, message="Missing model or messages"
                 )

             max_retries = optional_params.pop("max_retries", 2)
@@ -293,7 +292,10 @@ class AzureTextCompletion(BaseLLM):
                 "api-version", api_version
             )

-            response = azure_client.completions.create(**data, timeout=timeout)  # type: ignore
+            raw_response = azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
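Both the sync and async Azure text-completion paths now go through the OpenAI SDK's `with_raw_response` wrapper; `.parse()` returns the same typed completion object as before, while the raw response keeps the HTTP headers available. A rough sketch of the pattern in isolation, assuming OpenAI SDK 1.x with the usual Azure environment variables set; the deployment name is illustrative:

from openai import AzureOpenAI

# Assumes AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT / OPENAI_API_VERSION are set.
azure_client = AzureOpenAI()
raw_response = azure_client.completions.with_raw_response.create(
    model="chatgpt-v-2",  # Azure deployment name, illustrative
    prompt="hello",
    timeout=30,
)
print(raw_response.headers.get("x-request-id"))  # HTTP headers stay accessible
response = raw_response.parse()                  # same Completion object as .create()
print(response.model_dump())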
@@ -380,13 +382,15 @@ class AzureTextCompletion(BaseLLM):
                     "complete_input_dict": data,
                 },
             )
-            response = await azure_client.completions.create(**data, timeout=timeout)
+            raw_response = await azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
             return openai_text_completion_config.convert_to_chat_model_response_object(
                 response_object=response.model_dump(),
                 model_response_object=model_response,
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
             status_code = getattr(e, "status_code", 500)
@@ -1209,6 +1209,9 @@ def completion(
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
             or "ft:davinci-002" in model  # support for finetuned completion models
+            or custom_llm_provider
+            in litellm.openai_text_completion_compatible_providers
+            and kwargs.get("text_completion") is True
         ):
             openai.api_type = "openai"

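One subtlety in the new clause: `and` binds tighter than `or`, so the compatible providers only take this branch when the caller set `text_completion=True`; ordinary chat calls to those providers keep their normal path. A simplified sketch of how the condition evaluates (the fine-tuned-model clauses are omitted and the values are illustrative):

openai_text_completion_compatible_providers = ["together_ai", "fireworks_ai"]

def takes_text_completion_path(custom_llm_provider: str, kwargs: dict) -> bool:
    # Mirrors the provider clauses above; `and` groups before `or`.
    return (
        custom_llm_provider == "text-completion-openai"
        or custom_llm_provider in openai_text_completion_compatible_providers
        and kwargs.get("text_completion") is True
    )

print(takes_text_completion_path("fireworks_ai", {"text_completion": True}))  # True
print(takes_text_completion_path("fireworks_ai", {}))                         # False
print(takes_text_completion_path("text-completion-openai", {}))               # True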
@@ -4099,8 +4102,8 @@ def text_completion(

     kwargs.pop("prompt", None)

-    if (
-        _model is not None and custom_llm_provider == "openai"
+    if _model is not None and (
+        custom_llm_provider == "openai"
     ):  # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
         if _model not in litellm.open_ai_chat_completion_models:
             model = "text-completion-openai/" + _model
@@ -1,16 +1,9 @@
 model_list:
-  - model_name: "anthropic/claude-3-5-sonnet-20240620"
+  - model_name: "gpt-turbo"
     litellm_params:
-      model: anthropic/claude-3-5-sonnet-20240620
-      # api_base: http://0.0.0.0:9000
-  - model_name: gpt-3.5-turbo
-    litellm_params:
-      model: openai/*
+      model: azure/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE

-litellm_settings:
-  success_callback: ["s3"]
-  s3_callback_params:
-    s3_bucket_name: litellm-logs # AWS Bucket Name for S3
-    s3_region_name: us-west-2 # AWS Region Name for S3
-    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
-    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+router_settings:
+  model_group_alias: {"gpt-4": "gpt-turbo"}
@@ -3,7 +3,7 @@
 import asyncio
 import logging
 import random
-from typing import Optional
+from typing import List, Optional

 import litellm
 from litellm._logging import print_verbose
@@ -36,6 +36,25 @@ def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
     )


+def filter_deployments_by_id(
+    model_list: List,
+) -> List:
+    seen_ids = set()
+    filtered_deployments = []
+
+    for deployment in model_list:
+        _model_info = deployment.get("model_info") or {}
+        _id = _model_info.get("id") or None
+        if _id is None:
+            continue
+
+        if _id not in seen_ids:
+            seen_ids.add(_id)
+            filtered_deployments.append(deployment)
+
+    return filtered_deployments
+
+
 async def _perform_health_check(model_list: list, details: Optional[bool] = True):
     """
     Perform a health check for each model in the list.
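A quick illustration of what `filter_deployments_by_id` does with a model list in which an alias points at the same underlying deployment; it assumes the function from the hunk above is in scope, and the entries are made up:

model_list = [
    {"model_name": "gpt-turbo", "model_info": {"id": "deployment-1"}},
    {"model_name": "gpt-4", "model_info": {"id": "deployment-1"}},  # alias, same id
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "deployment-2"}},
    {"model_name": "no-id-model"},  # skipped: no model_info.id
]

print(filter_deployments_by_id(model_list=model_list))
# keeps only the first "deployment-1" entry and the "deployment-2" entry,
# so each deployment is health-checked once.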
@@ -105,6 +124,9 @@ async def perform_health_check(
         _new_model_list = [x for x in model_list if x["model_name"] == model]
         model_list = _new_model_list

+    model_list = filter_deployments_by_id(
+        model_list=model_list
+    )  # filter duplicate deployments (e.g. when model alias'es are used)
     healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
         model_list, details
     )
@@ -109,7 +109,7 @@ async def add_new_member(
             where={"user_id": user_info.user_id},  # type: ignore
             data={"teams": {"push": [team_id]}},
         )
-
+        if _returned_user is not None:
             returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
     elif len(existing_user_row) > 1:
         raise HTTPException(
@@ -4556,6 +4556,27 @@ class Router:
                 ids.append(id)
         return ids

+    def _get_all_deployments(
+        self, model_name: str, model_alias: Optional[str] = None
+    ) -> List[DeploymentTypedDict]:
+        """
+        Return all deployments of a model name
+
+        Used for accurate 'get_model_list'.
+        """
+
+        returned_models: List[DeploymentTypedDict] = []
+        for model in self.model_list:
+            if model["model_name"] == model_name:
+                if model_alias is not None:
+                    alias_model = copy.deepcopy(model)
+                    alias_model["model_name"] = model_name
+                    returned_models.append(alias_model)
+                else:
+                    returned_models.append(model)
+
+        return returned_models
+
     def get_model_names(self) -> List[str]:
         """
         Returns all possible model names for router.
@@ -4567,15 +4588,18 @@ class Router:
     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
+        """
+        Includes router model_group_alias'es as well
+        """
         if hasattr(self, "model_list"):
             returned_models: List[DeploymentTypedDict] = []
+
             for model_alias, model_value in self.model_group_alias.items():
-                model_alias_item = DeploymentTypedDict(
-                    model_name=model_alias,
-                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
+                returned_models.extend(
+                    self._get_all_deployments(
+                        model_name=model_value, model_alias=model_alias
+                    )
                 )
-                returned_models.append(model_alias_item)

             if model_name is None:
                 returned_models += self.model_list
@@ -4583,8 +4607,7 @@ class Router:
                 return returned_models

             for model in self.model_list:
-                if model["model_name"] == model_name:
-                    returned_models.append(model)
+                returned_models.extend(self._get_all_deployments(model_name=model_name))

             return returned_models
         return None
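Taken together, these router hunks make `get_model_list()` return the real underlying deployment for an aliased model group instead of a synthetic entry. A minimal sketch of the behavior this fixes (#5524), using placeholder credentials read from the environment:

import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        }
    ],
    model_group_alias={"gpt-4": "gpt-turbo"},
)

# Before this change the "gpt-4" alias came back as a bare
# {"model_name": "gpt-4", "litellm_params": {"model": "gpt-turbo"}} stub;
# now the returned list carries the underlying azure/chatgpt-v-2 deployment.
print(router.get_model_list())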
@@ -626,6 +626,8 @@ async def test_model_function_invoke(model, sync_mode, api_key, api_base):
            response = await litellm.acompletion(**data)

        print(f"response: {response}")
+    except litellm.InternalServerError:
+        pass
    except litellm.RateLimitError as e:
        pass
    except Exception as e:
@@ -864,7 +864,7 @@ def _pre_call_utils(
         data["messages"] = [{"role": "user", "content": "Hello world"}]
         if streaming is True:
             data["stream"] = True
-        mapped_target = client.chat.completions.with_raw_response
+        mapped_target = client.chat.completions.with_raw_response  # type: ignore
         if sync_mode:
             original_function = litellm.completion
         else:
@@ -873,7 +873,7 @@ def _pre_call_utils(
         data["prompt"] = "Hello world"
         if streaming is True:
             data["stream"] = True
-        mapped_target = client.completions.with_raw_response
+        mapped_target = client.completions.with_raw_response  # type: ignore
         if sync_mode:
             original_function = litellm.text_completion
         else:
@@ -52,6 +52,7 @@ def get_current_weather(location, unit="fahrenheit"):
         # "anthropic.claude-3-sonnet-20240229-v1:0",
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 def test_aaparallel_function_call(model):
     try:
         litellm.set_verbose = True
@@ -4239,3 +4239,14 @@ def test_completion_vllm():
         mock_call.assert_called_once()

         assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
+def test_completion_fireworks_ai_multiple_choices():
+    litellm.set_verbose = True
+    response = litellm.text_completion(
+        model="fireworks_ai/llama-v3p1-8b-instruct",
+        prompt=["halo", "hi", "halo", "hi"],
+    )
+    print(response.choices)
+
+    assert len(response.choices) == 4
@@ -148,6 +148,7 @@ router_settings:
   redis_password: os.environ/REDIS_PASSWORD
   redis_port: os.environ/REDIS_PORT
   enable_pre_call_checks: true
+  model_group_alias: {"my-special-fake-model-alias-name": "fake-openai-endpoint-3"}

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
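With the alias in `router_settings`, a client pointed at the proxy can request the alias name directly and the router resolves it to the `fake-openai-endpoint-3` deployment. A sketch using the OpenAI SDK; the proxy URL and port are assumed local defaults, not part of this diff:

import openai

client = openai.OpenAI(
    api_key="sk-1234",               # the proxy master key from general_settings
    base_url="http://0.0.0.0:4000",  # assumed default proxy address
)

response = client.chat.completions.create(
    model="my-special-fake-model-alias-name",  # resolved via model_group_alias
    messages=[{"role": "user", "content": "hi"}],
)
print(response)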