LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: cleanup unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686

* fix(o1_reasoning.py): add stricter check for o-1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o-1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685

* feat(main.py): infer finetuned gemini model from base model (usage sketch after this list)

Fixes https://github.com/BerriAI/litellm/issues/5678

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases (see the sketch after this list)

Closes https://github.com/BerriAI/litellm/issues/5692

* docs(load_balancing.md): add docs on hiding alias models from proxy config

* fix(base.py): don't raise NotImplementedError

* fix(user_api_key_auth.py): fix model max budget check

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_tokens mapping (example after this list)

* test(test_image_generation.py): mark flaky test
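
For the finetuned-Gemini change, a minimal usage sketch, assuming the new base_model kwarg this commit wires through completion(); the endpoint ID and base model name are hypothetical:

import litellm

# The deployed endpoint ID contains no "gemini" substring, so litellm falls
# back to base_model to pick the Vertex Gemini code path (see the main.py
# diff below).
response = litellm.completion(
    model="vertex_ai/1234567890123456789",  # hypothetical finetuned endpoint ID
    base_model="vertex_ai/gemini-1.5-pro",  # assumption: any value containing "gemini"
    messages=[{"role": "user", "content": "Hello"}],
)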
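
For the hidden-alias feature, a hedged sketch using the Python Router; the dict-valued alias schema with a "hidden" flag is an assumption drawn from the load-balancing docs added in this commit:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    # Assumption: "hidden": True keeps the alias out of /v1/models and
    # /model/info responses while requests sent to it still route normally.
    model_group_alias={"internal-alias": {"model": "gpt-3.5-turbo", "hidden": True}},
)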
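
And for the o-1 parameter-mapping fixes: OpenAI's o-1 models reject max_tokens in favor of max_completion_tokens, so litellm translates the param before the call. A minimal sketch; the model name is illustrative:

import litellm

response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    max_tokens=256,  # mapped to max_completion_tokens=256 for o-1 models
)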

Krish Dholakia, 2024-09-14 10:02:55 -07:00, committed by GitHub
parent 60c5d3ebec
commit 713d762411
35 changed files with 1020 additions and 539 deletions

litellm/main.py

@@ -737,6 +737,7 @@ def completion(
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
     supports_system_message = kwargs.get("supports_system_message", None)
+    base_model = kwargs.get("base_model", None)
     ### TEXT COMPLETION CALLS ###
     text_completion = kwargs.get("text_completion", False)
     atext_completion = kwargs.get("atext_completion", False)
@@ -782,11 +783,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = (
-        all_litellm_params  # use the external var., used in creating cache key as well.
-    )
-    default_params = openai_params + litellm_params
+    default_params = openai_params + all_litellm_params
+    litellm_params = {}  # used to prevent unbound var errors
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
@@ -973,6 +972,7 @@ def completion(
         text_completion=kwargs.get("text_completion"),
         azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
         user_continue_message=kwargs.get("user_continue_message"),
+        base_model=base_model,
     )
     logging.update_environment_variables(
         model=model,
@@ -2123,7 +2123,10 @@ def completion(
                 timeout=timeout,
                 client=client,
             )
-        elif "gemini" in model:
+        elif "gemini" in model or (
+            litellm_params.get("base_model") is not None
+            and "gemini" in litellm_params["base_model"]
+        ):
             model_response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
                 messages=messages,
@@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs):
     )
     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "constant_retry":
         retryer = tenacity.Retrying(
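
A usage sketch for the now-typed retry_strategy; both literal values come from the annotation in the hunk above:

import litellm

response = litellm.completion_with_retries(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    num_retries=3,  # popped from kwargs; defaults to 3
    retry_strategy="exponential_backoff_retry",  # or "constant_retry" (the default)
)
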
@@ -4997,7 +5000,9 @@ def speech(
 async def ahealth_check(
     model_params: dict,
     mode: Optional[
-        Literal["completion", "embedding", "image_generation", "chat", "batch"]
+        Literal[
+            "completion", "embedding", "image_generation", "chat", "batch", "rerank"
+        ]
     ] = None,
     prompt: Optional[str] = None,
     input: Optional[List] = None,
@@ -5113,6 +5118,12 @@ async def ahealth_check(
         model_params["prompt"] = prompt
         await litellm.aimage_generation(**model_params)
         response = {}
+    elif mode == "rerank":
+        model_params.pop("messages", None)
+        model_params["query"] = prompt
+        model_params["documents"] = ["my sample text"]
+        await litellm.arerank(**model_params)
+        response = {}
     elif "*" in model:
         from litellm.litellm_core_utils.llm_request_utils import (
             pick_cheapest_model_from_llm_provider,
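
A hedged sketch of the new rerank health-check mode: per the hunk above, prompt becomes the rerank query and a fixed sample document is supplied; the model name is illustrative:

import asyncio

import litellm

async def main():
    response = await litellm.ahealth_check(
        model_params={"model": "cohere/rerank-english-v3.0"},
        mode="rerank",
        prompt="which document mentions health checks?",
    )
    print(response)

asyncio.run(main())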