LiteLLM Minor Fixes & Improvement (11/14/2024) (#6730)

* fix(ollama.py): fix get model info request

Fixes https://github.com/BerriAI/litellm/issues/6703

* feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param

* docs(anthropic.md): document all supported openai params for anthropic

* test: fix tests

* fix: fix tests

* feat(jina_ai/): add rerank support

Closes https://github.com/BerriAI/litellm/issues/6691

* test: handle service unavailable error

* fix(handler.py): refactor together ai rerank call

* test: update test to handle overloaded error

* test: fix test

* Litellm router trace (#6742)

* feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks

* feat(router.py): log trace id across retry/fallback logic

allows grouping llm logs for the same request

* test: fix tests

* fix: fix test

* fix(transformation.py): only set non-none stop_sequences

* Litellm router disable fallbacks (#6743)

* bump: version 1.52.6 → 1.52.7

* feat(router.py): enable dynamically disabling fallbacks

Allows for enabling/disabling fallbacks per key

* feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key

* test: fix test

* fix(exception_mapping_utils.py): map 'model is overloaded' to internal server error

* test: handle gemini error

* test: fix test

* fix: new run
This commit is contained in:
Krish Dholakia 2024-11-15 01:02:54 +05:30 committed by GitHub
parent 61bd742e72
commit 2bf23b0c7d
35 changed files with 853 additions and 246 deletions

View file

@ -679,9 +679,8 @@ class Router:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
kwargs.get("request_timeout", self.timeout)
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
@ -783,8 +782,7 @@ class Router:
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
@ -948,6 +946,17 @@ class Router:
self.fail_calls[model_name] += 1
raise e
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
"""
Adds/updates to kwargs:
- num_retries
- litellm_trace_id
- metadata
"""
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
"""
Adds default litellm params to kwargs, if set.
@ -1511,9 +1520,7 @@ class Router:
kwargs["model"] = model
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@ -1688,9 +1695,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._arerank
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
@ -1839,9 +1844,7 @@ class Router:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@ -2112,9 +2115,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._aembedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
@ -2609,6 +2610,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
model_group: Optional[str] = kwargs.get("model")
disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks: Optional[List] = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
@ -2616,6 +2618,7 @@ class Router:
content_policy_fallbacks: Optional[List] = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
self._handle_mock_testing_fallbacks(
kwargs=kwargs,
@ -2635,7 +2638,7 @@ class Router:
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = ""
if original_model_group is None:
if disable_fallbacks is True or original_model_group is None:
raise e
input_kwargs = {