LiteLLM Minor Fixes & Improvements (11/26/2024) (#6913)

* docs(config_settings.md): document all router_settings

* ci(config.yml): add router_settings doc test to ci/cd

* test: debug test on ci/cd

* test: debug ci/cd test

* test: fix test

* fix(team_endpoints.py): skip invalid team object. don't fail `/team/list` call

Failing the call causes downstream errors when the UI cannot load the team list

* test(base_llm_unit_tests.py): add 'response_format={"type": "text"}' test to base_llm_unit_tests

adds complete coverage for all 'response_format' values to ci/cd

* feat(router.py): support wildcard routes in `get_router_model_info()`

Addresses https://github.com/BerriAI/litellm/issues/6914

* build(model_prices_and_context_window.json): add tpm/rpm limits for all gemini models

Allows for ratelimit tracking for gemini models even with wildcard routing enabled

Addresses https://github.com/BerriAI/litellm/issues/6914

* feat(router.py): add tpm/rpm tracking on success/failure to global_router

Addresses https://github.com/BerriAI/litellm/issues/6914

* feat(router.py): support wildcard routes on router.get_model_group_usage()

* fix(router.py): fix linting error

* fix(router.py): implement get_remaining_tokens_and_requests

Addresses https://github.com/BerriAI/litellm/issues/6914

* fix(router.py): fix linting errors

* test: fix test

* test: fix tests

* docs(config_settings.md): add missing dd env vars to docs

* fix(router.py): check if hidden params is dict
This commit is contained in:
Krish Dholakia 2024-11-28 00:01:38 +05:30 committed by GitHub
parent 5d13302e6b
commit 2d2931a215
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 878 additions and 131 deletions

View file

@ -46,17 +46,22 @@ print(env_keys)
repo_base = "./"
print(os.listdir(repo_base))
docs_path = (
"../../docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
)
documented_keys = set()
try:
with open(docs_path, "r", encoding="utf-8") as docs_file:
content = docs_file.read()
print(f"content: {content}")
# Find the section titled "environment variables - Reference"
general_settings_section = re.search(
r"### environment variables - Reference(.*?)###", content, re.DOTALL
r"### environment variables - Reference(.*?)(?=\n###|\Z)",
content,
re.DOTALL | re.MULTILINE,
)
print(f"general_settings_section: {general_settings_section}")
if general_settings_section:
# Extract the table rows, which contain the documented keys
table_content = general_settings_section.group(1)
@ -70,6 +75,7 @@ except Exception as e:
)
print(f"documented_keys: {documented_keys}")
# Compare and find undocumented keys
undocumented_keys = env_keys - documented_keys

View file

@ -0,0 +1,87 @@
import os
import re
import inspect
from typing import Type
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
def get_init_params(cls: Type) -> list[str]:
    """
    Retrieve all named parameters accepted by the `__init__` method of a class.

    Both positional-or-keyword and keyword-only parameters are returned, so a
    parameter declared after a bare `*` is not silently skipped. The `*args` /
    `**kwargs` catch-alls are not included.

    Args:
        cls: The class to inspect.

    Returns:
        A list of parameter names, excluding the leading `self`.
    """
    # Every class inherits `__init__` from `object`, so there is nothing to
    # guard against before inspecting it.
    argspec = inspect.getfullargspec(cls.__init__)
    # The first positional argument is 'self', so we exclude it.
    return argspec.args[1:] + argspec.kwonlyargs
# Collect every parameter name accepted by `Router.__init__`; each of these is
# expected to appear in the 'router_settings - Reference' docs table.
router_init_params = set(get_init_params(litellm.router.Router))
print(router_init_params)
# `model_list` is excluded from the documentation check.
router_init_params.remove("model_list")

# Parse the documentation to extract documented keys
repo_base = "./"
print(os.listdir(repo_base))
docs_path = (
    "./docs/my-website/docs/proxy/config_settings.md"  # Path to the documentation
)
documented_keys = set()
try:
    with open(docs_path, "r", encoding="utf-8") as docs_file:
        content = docs_file.read()

        # Find the section titled "router_settings - Reference". The section
        # ends at the next line starting with "###" or at end of file, so it
        # is still found when it is the last section in the document.
        router_settings_section = re.search(
            r"### router_settings - Reference(.*?)(?=\n###|\Z)",
            content,
            re.DOTALL,
        )
        if router_settings_section:
            # Extract the table rows, which contain the documented keys
            table_content = router_settings_section.group(1)
            doc_key_pattern = re.compile(
                r"\|\s*([^\|]+?)\s*\|"
            )  # Capture the key from each row of the table
            documented_keys.update(doc_key_pattern.findall(table_content))
except Exception as e:
    # Include the directory listing so a wrong working directory is easy to
    # spot in CI logs; chain the original error for the full traceback.
    raise Exception(
        f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
    ) from e

# Compare and find undocumented keys
undocumented_keys = router_init_params - documented_keys

# Print results
print("Keys expected in 'router settings' (found in code):")
for key in sorted(router_init_params):
    print(key)

if undocumented_keys:
    raise Exception(
        f"\nKeys not documented in 'router settings - Reference': {undocumented_keys}"
    )
else:
    print(
        "\nAll keys are documented in 'router settings - Reference'. - {}".format(
            router_init_params
        )
    )

View file

@ -62,7 +62,14 @@ class BaseLLMChatTest(ABC):
response = litellm.completion(**base_completion_call_args, messages=messages)
assert response is not None
def test_json_response_format(self):
@pytest.mark.parametrize(
"response_format",
[
{"type": "json_object"},
{"type": "text"},
],
)
def test_json_response_format(self, response_format):
"""
Test that the JSON response format is supported by the LLM API
"""
@ -83,7 +90,7 @@ class BaseLLMChatTest(ABC):
response = litellm.completion(
**base_completion_call_args,
messages=messages,
response_format={"type": "json_object"},
response_format=response_format,
)
print(response)

View file

@ -102,3 +102,17 @@ def test_get_model_info_ollama_chat():
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
def test_get_model_info_gemini():
    """
    Check that every `gemini/` entry in `litellm.model_cost` whose name does
    not contain "gemma" has both 'tpm' and 'rpm' set.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    for name, details in litellm.model_cost.items():
        # Only gemini-prefixed, non-gemma entries are subject to the check.
        if "gemma" in name or not name.startswith("gemini/"):
            continue
        assert details.get("tpm") is not None, f"{name} does not have tpm"
        assert details.get("rpm") is not None, f"{name} does not have rpm"

View file

@ -2115,10 +2115,14 @@ def test_router_get_model_info(model, base_model, llm_provider):
assert deployment is not None
if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
else:
try:
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
pytest.fail("Expected this to raise model not mapped error")
except Exception as e:
if "This model isn't mapped yet" in str(e):

View file

@ -174,3 +174,185 @@ async def test_update_kwargs_before_fallbacks(call_type):
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
def test_router_get_model_info_wildcard_routes():
    """
    `get_router_model_info()` should return tpm/rpm values for a concrete
    model name when the router only holds a `gemini/*` wildcard deployment.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    info = router.get_router_model_info(
        deployment=None,
        received_model_name="gemini/gemini-1.5-flash",
        id="1",
    )
    print(info)

    assert info is not None
    for limit_key in ("tpm", "rpm"):
        assert info[limit_key] is not None
@pytest.mark.asyncio
async def test_router_get_model_group_usage_wildcard_routes():
    """
    After one mocked completion through a `gemini/*` wildcard deployment,
    `get_model_group_usage()` should report non-None tpm and rpm for the
    concrete model name.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    completion = await router.acompletion(
        model="gemini/gemini-1.5-flash",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="Hello, I'm good.",
    )
    print(completion)

    # Pause before reading usage back.
    await asyncio.sleep(1)

    usage = await router.get_model_group_usage(model_group="gemini/gemini-1.5-flash")
    tpm, rpm = usage
    assert tpm is not None, "tpm is None"
    assert rpm is not None, "rpm is None"
@pytest.mark.asyncio
async def test_call_router_callbacks_on_success():
    """
    A successful mocked completion through a `gemini/*` wildcard deployment
    should call `router.cache.async_increment_cache` exactly twice: first with
    a tpm key, then with an rpm key, both under the `global_router:1:` prefix.
    """
    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    increment_patch = patch.object(
        router.cache, "async_increment_cache", new=AsyncMock()
    )
    with increment_patch as increment_mock:
        await router.acompletion(
            model="gemini/gemini-1.5-flash",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            mock_response="Hello, I'm good.",
        )
        await asyncio.sleep(1)

        assert increment_mock.call_count == 2

        tpm_key = increment_mock.call_args_list[0].kwargs["key"]
        assert tpm_key.startswith("global_router:1:gemini/gemini-1.5-flash:tpm")

        rpm_key = increment_mock.call_args_list[1].kwargs["key"]
        assert rpm_key.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
@pytest.mark.asyncio
async def test_call_router_callbacks_on_failure():
    """
    A mocked rate-limit failure through a `gemini/*` wildcard deployment
    should call `router.cache.async_increment_cache` exactly once, with an
    rpm key under the `global_router:1:` prefix.
    """
    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    increment_patch = patch.object(
        router.cache, "async_increment_cache", new=AsyncMock()
    )
    with increment_patch as increment_mock:
        # Only the completion call sits inside `pytest.raises`; anything after
        # it in that block would never execute once the error is raised.
        with pytest.raises(litellm.RateLimitError):
            await router.acompletion(
                model="gemini/gemini-1.5-flash",
                messages=[{"role": "user", "content": "Hello, how are you?"}],
                mock_response="litellm.RateLimitError",
                num_retries=0,
            )

        await asyncio.sleep(1)
        print(increment_mock.call_args_list)

        assert increment_mock.call_count == 1

        rpm_key = increment_mock.call_args_list[0].kwargs["key"]
        assert rpm_key.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
@pytest.mark.asyncio
async def test_router_model_group_headers():
    """
    After two mocked completions through a `gemini/*` wildcard deployment, the
    last response's hidden `additional_headers` should carry the concrete
    model group plus the remaining-requests / remaining-tokens headers.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    from litellm.types.utils import OPENAI_RESPONSE_HEADERS  # not referenced below

    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    # NOTE(review): the sleep is placed inside the loop; the original
    # indentation is not visible here, so confirm the intended loop extent.
    for _ in range(2):
        resp = await router.acompletion(
            model="gemini/gemini-1.5-flash",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            mock_response="Hello, I'm good.",
        )
        await asyncio.sleep(1)

    headers = resp._hidden_params["additional_headers"]
    assert headers["x-litellm-model-group"] == "gemini/gemini-1.5-flash"
    assert "x-ratelimit-remaining-requests" in headers
    assert "x-ratelimit-remaining-tokens" in headers
@pytest.mark.asyncio
async def test_get_remaining_model_group_usage():
    """
    After two mocked completions through a `gemini/*` wildcard deployment,
    `get_remaining_model_group_usage()` should return a mapping containing the
    remaining-requests and remaining-tokens headers for the concrete model.
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    from litellm.types.utils import OPENAI_RESPONSE_HEADERS  # not referenced below

    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {"id": 1},
    }
    router = Router(model_list=[wildcard_deployment])

    # NOTE(review): the sleep is placed inside the loop; the original
    # indentation is not visible here, so confirm the intended loop extent.
    for _ in range(2):
        await router.acompletion(
            model="gemini/gemini-1.5-flash",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            mock_response="Hello, I'm good.",
        )
        await asyncio.sleep(1)

    remaining = await router.get_remaining_model_group_usage(
        model_group="gemini/gemini-1.5-flash"
    )
    assert remaining is not None
    for header in ("x-ratelimit-remaining-requests", "x-ratelimit-remaining-tokens"):
        assert header in remaining

View file

@ -506,7 +506,7 @@ async def test_router_caching_ttl():
) as mock_client:
await router.acompletion(model=model, messages=messages)
mock_client.assert_called_once()
# mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")

View file

@ -396,7 +396,8 @@ async def test_deployment_callback_on_success(model_list, sync_mode):
assert tpm_key is not None
def test_deployment_callback_on_failure(model_list):
@pytest.mark.asyncio
async def test_deployment_callback_on_failure(model_list):
"""Test if the '_deployment_callback_on_failure' function is working correctly"""
import time
@ -418,6 +419,18 @@ def test_deployment_callback_on_failure(model_list):
assert isinstance(result, bool)
assert result is False
model_response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
result = await router.async_deployment_callback_on_failure(
kwargs=kwargs,
completion_response=model_response,
start_time=time.time(),
end_time=time.time(),
)
def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""