Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
fix(utils.py): handle scenario where model="azure/*" and custom_llm_provider="azure"
Fixes https://github.com/BerriAI/litellm/issues/4912
This commit is contained in:
parent 3ee8ae231c
commit 5d96ff6694

5 changed files with 23 additions and 21 deletions
@@ -1,8 +1,7 @@
 model_list:
-  - model_name: "predibase-llama"
+  - model_name: "*"
     litellm_params:
-      model: "predibase/llama-3-8b-instruct"
-      request_timeout: 1
+      model: "*"
 
-litellm_settings:
-  failure_callback: ["langfuse"]
+# litellm_settings:
+#   failure_callback: ["langfuse"]
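With this config change the proxy registers a single wildcard deployment, so whatever model string the client sends is forwarded as-is; that is how prefixed names such as "azure/gpt-4o" or even "azure/*" can reach get_llm_provider() downstream. A minimal sketch of a call against such a proxy; the base_url, port, key, and model name below are assumptions for illustration, not part of this commit:

# Sketch: calling a LiteLLM proxy that only defines the wildcard deployment above.
# Assumes the proxy runs locally on port 4000 with a placeholder key.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# With model_name "*" the proxy forwards this model string unchanged,
# which is how provider-prefixed inputs end up in get_llm_provider().
response = client.chat.completions.create(
    model="azure/gpt-4o",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)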
@@ -472,11 +472,10 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"Inside Max Parallel Request Failure Hook")
-            global_max_parallel_requests = (
-                kwargs["litellm_params"]
-                .get("metadata", {})
-                .get("global_max_parallel_requests", None)
+            self.print_verbose("Inside Max Parallel Request Failure Hook")
+            _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
+            global_max_parallel_requests = _metadata.get(
+                "global_max_parallel_requests", None
             )
             user_api_key = (
                 kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
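The refactor reads the request metadata once and guards against it being None with `or {}` before looking up the per-request limit. A standalone sketch of that pattern, using only the kwargs shape visible in the diff:

# Sketch of the metadata lookup pattern above: tolerate metadata being missing
# or explicitly None, then read the setting with a default.
def read_global_max_parallel_requests(kwargs: dict):
    _metadata = kwargs["litellm_params"].get("metadata", {}) or {}
    return _metadata.get("global_max_parallel_requests", None)

print(read_global_max_parallel_requests({"litellm_params": {"metadata": None}}))  # None, no AttributeError
print(read_global_max_parallel_requests({"litellm_params": {"metadata": {"global_max_parallel_requests": 10}}}))  # 10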
@@ -1959,6 +1959,7 @@ class ProxyConfig:
                     if len(_value) > 0:
                         _litellm_params[k] = _value
                 _litellm_params = LiteLLM_Params(**_litellm_params)
 
             else:
                 verbose_proxy_logger.error(
                     f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
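Only context lines of this hunk survive, but the visible pattern is: copy non-empty values into _litellm_params, build a LiteLLM_Params object from them, and log an error for invalid entries. A rough, self-contained sketch of that shape; ParamsModel and its fields are illustrative stand-ins, not litellm's actual schema:

# Rough sketch of the validate-or-log shape visible above.
import logging
from typing import Optional
from pydantic import BaseModel

logger = logging.getLogger("proxy")

class ParamsModel(BaseModel):  # stand-in for LiteLLM_Params
    model: str
    api_key: Optional[str] = None

def build_params(raw) -> Optional[ParamsModel]:
    if isinstance(raw, dict):
        _litellm_params = {}
        for k, _value in raw.items():
            if isinstance(_value, str) and len(_value) > 0:  # keep only non-empty values
                _litellm_params[k] = _value
        return ParamsModel(**_litellm_params)
    logger.error(
        "Invalid model added to proxy db. Invalid litellm params. litellm_params=%s", raw
    )
    return None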
@@ -304,7 +304,7 @@ class Message(OpenAIObject):
         content: Optional[str] = None,
         role: Literal["assistant"] = "assistant",
         function_call=None,
-        tool_calls=None,
+        tool_calls: Optional[list] = None,
         **params,
     ):
         init_values = {
@@ -322,7 +322,7 @@ class Message(OpenAIObject):
                 )
                 for tool_call in tool_calls
             ]
-            if tool_calls is not None
+            if tool_calls is not None and len(tool_calls) > 0
             else None
         ),
     }
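Together, the two Message changes make an empty tool_calls list normalize to None instead of an empty list. A standalone sketch of just that normalization; nothing here imports litellm, the logic simply mirrors the diff:

# Sketch of the tool_calls normalization after this change: an empty list is
# treated the same as None.
from typing import Optional

def normalize_tool_calls(tool_calls: Optional[list]) -> Optional[list]:
    return (
        [dict(tc) for tc in tool_calls]  # placeholder for ChatCompletionMessageToolCall(**tc)
        if tool_calls is not None and len(tool_calls) > 0
        else None
    )

print(normalize_tool_calls(None))  # None
print(normalize_tool_calls([]))    # None (previously this produced an empty list)
print(normalize_tool_calls([{"id": "call_1", "type": "function"}]))  # one converted entry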
@@ -445,8 +445,6 @@ class Choices(OpenAIObject):
 
 
 class Usage(OpenAIObject):
-    prompt_cache_hit_tokens: Optional[int] = Field(default=None)
-    prompt_cache_miss_tokens: Optional[int] = Field(default=None)
     prompt_tokens: Optional[int] = Field(default=None)
     completion_tokens: Optional[int] = Field(default=None)
     total_tokens: Optional[int] = Field(default=None)
@@ -456,16 +454,15 @@ class Usage(OpenAIObject):
         prompt_tokens: Optional[int] = None,
         completion_tokens: Optional[int] = None,
         total_tokens: Optional[int] = None,
-        prompt_cache_hit_tokens: Optional[int] = None,
-        prompt_cache_miss_tokens: Optional[int] = None,
+        **params,
     ):
         data = {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "prompt_cache_hit_tokens": prompt_cache_hit_tokens,
-            "prompt_cache_miss_tokens": prompt_cache_miss_tokens,
+            **params,
         }
 
         super().__init__(**data)
 
     def __contains__(self, key):
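With the provider-specific prompt_cache_hit_tokens / prompt_cache_miss_tokens fields removed, such usage counters now ride along via **params. A standalone sketch of the pass-through; UsageSketch stands in for litellm's Usage/OpenAIObject, and extra="allow" is an assumption about how unknown keys get stored:

# Standalone sketch of the **params pass-through pattern the Usage class now uses.
from typing import Optional
from pydantic import BaseModel, ConfigDict

class UsageSketch(BaseModel):  # stand-in for Usage(OpenAIObject)
    model_config = ConfigDict(extra="allow")
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None

    def __init__(self, prompt_tokens=None, completion_tokens=None, total_tokens=None, **params):
        data = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            **params,  # provider-specific counters, e.g. prompt_cache_hit_tokens
        }
        super().__init__(**data)

u = UsageSketch(prompt_tokens=10, completion_tokens=5, total_tokens=15, prompt_cache_hit_tokens=3)
print(u.prompt_cache_hit_tokens)  # 3, carried as an extra field rather than a declared one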
@@ -4444,6 +4444,11 @@ def get_llm_provider(
             return model, custom_llm_provider, dynamic_api_key, api_base
 
         if custom_llm_provider:
+            if (
+                model.split("/")[0] == custom_llm_provider
+            ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
+                model = model.replace("{}/".format(custom_llm_provider), "")
+
             return model, custom_llm_provider, dynamic_api_key, api_base
 
         if api_key and api_key.startswith("os.environ/"):
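This hunk in utils.py is the fix named in the commit title: when the model string carries the same provider prefix that was passed as custom_llm_provider, the prefix is stripped before the provider tuple is returned. A cut-down sketch of just this branch (not the full get_llm_provider; the function name and two-element return here are simplified for illustration):

# Cut-down sketch of the new branch: strip a provider prefix when it matches
# custom_llm_provider, so "azure/*" + "azure" yields "*".
from typing import Optional

def resolve_model(model: str, custom_llm_provider: Optional[str]):
    if custom_llm_provider:
        if (
            model.split("/")[0] == custom_llm_provider
        ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
            model = model.replace("{}/".format(custom_llm_provider), "")
    return model, custom_llm_provider

print(resolve_model("azure/*", "azure"))       # ('*', 'azure')
print(resolve_model("azure/gpt-4o", "azure"))  # ('gpt-4o', 'azure')
print(resolve_model("gpt-4o", "openai"))       # ('gpt-4o', 'openai'), prefix-free names pass through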
@@ -5825,9 +5830,10 @@ def convert_to_model_response_object(
                 model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0)  # type: ignore
                 model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)  # type: ignore
                 model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
-                model_response_object.usage.prompt_cache_hit_tokens = response_object["usage"].get("prompt_cache_hit_tokens", None)  # type: ignore
-                model_response_object.usage.prompt_cache_miss_tokens = response_object["usage"].get("prompt_cache_miss_tokens", None)  # type: ignore
+                special_keys = ["completion_tokens", "prompt_tokens", "total_tokens"]
+                for k, v in response_object["usage"].items():
+                    if k not in special_keys:
+                        setattr(model_response_object.usage, k, v)  # type: ignore
 
             if "created" in response_object:
                 model_response_object.created = response_object["created"] or int(
                     time.time()
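Rather than hard-coding the two cache counters, the usage copy now forwards any non-standard usage key onto the response's usage object. A standalone sketch of that loop, with a SimpleNamespace standing in for the real usage object:

# Standalone sketch of the generic usage-copy loop above: the three standard
# counters are set explicitly, everything else is forwarded with setattr().
from types import SimpleNamespace

def copy_usage(response_usage: dict) -> SimpleNamespace:
    usage = SimpleNamespace(
        completion_tokens=response_usage.get("completion_tokens", 0),
        prompt_tokens=response_usage.get("prompt_tokens", 0),
        total_tokens=response_usage.get("total_tokens", 0),
    )
    special_keys = ["completion_tokens", "prompt_tokens", "total_tokens"]
    for k, v in response_usage.items():
        if k not in special_keys:
            setattr(usage, k, v)  # provider-specific fields, e.g. prompt_cache_hit_tokens
    return usage

u = copy_usage({"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, "prompt_cache_hit_tokens": 3})
print(u.prompt_cache_hit_tokens)  # 3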