From 70111a7abd055d2611f6b5dc4bcbcb11c1862697 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 28 Oct 2024 15:05:43 -0700 Subject: [PATCH] Litellm dev 10 26 2024 (#6472) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls --- docs/my-website/docs/exception_mapping.md | 39 ++++++--- litellm/_service_logger.py | 56 ++++++++++++- litellm/caching/redis_cache.py | 62 ++++++++++++++ litellm/exceptions.py | 18 ++--- litellm/main.py | 18 +---- litellm/proxy/_new_secret_config.yaml | 46 +++-------- litellm/utils.py | 1 + .../test_exception_types.py | 81 +++++++++++++++++++ tests/local_testing/test_router_fallbacks.py | 61 ++++++++++++++ 9 files changed, 310 insertions(+), 72 deletions(-) create mode 100644 tests/documentation_tests/test_exception_types.py diff --git a/docs/my-website/docs/exception_mapping.md b/docs/my-website/docs/exception_mapping.md index 5e6006ebe..13eda5b40 100644 --- a/docs/my-website/docs/exception_mapping.md +++ b/docs/my-website/docs/exception_mapping.md @@ -2,18 +2,33 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts. -| Status Code | Error Type | -|-------------|--------------------------| -| 400 | BadRequestError | -| 401 | AuthenticationError | -| 403 | PermissionDeniedError | -| 404 | NotFoundError | -| 422 | UnprocessableEntityError | -| 429 | RateLimitError | -| >=500 | InternalServerError | -| N/A | ContextWindowExceededError| -| 400 | ContentPolicyViolationError| -| 500 | APIConnectionError | +All exceptions can be imported from `litellm` - e.g. 
`from litellm import BadRequestError`
+
+## LiteLLM Exceptions
+
+| Status Code | Error Type | Inherits from | Description |
+|-------------|--------------------------|---------------|-------------|
+| 400 | BadRequestError | openai.BadRequestError | Raised when the provider returns a 400 for a malformed or invalid request |
+| 400 | UnsupportedParamsError | litellm.BadRequestError | Raised when unsupported params are passed |
+| 400 | ContextWindowExceededError | litellm.BadRequestError | Special error type for context window exceeded error messages - enables context window fallbacks |
+| 400 | ContentPolicyViolationError | litellm.BadRequestError | Special error type for content policy violation error messages - enables content policy fallbacks |
+| 400 | InvalidRequestError | openai.BadRequestError | Deprecated error, use BadRequestError instead |
+| 401 | AuthenticationError | openai.AuthenticationError | Raised when the provider rejects the API key or other credentials |
+| 403 | PermissionDeniedError | openai.PermissionDeniedError | Raised when the credentials lack permission for the requested resource |
+| 404 | NotFoundError | openai.NotFoundError | Raised when an invalid model is passed, e.g. `gpt-8` |
+| 408 | Timeout | openai.APITimeoutError | Raised when a timeout occurs |
+| 422 | UnprocessableEntityError | openai.UnprocessableEntityError | Raised when the provider cannot process the request entity |
+| 429 | RateLimitError | openai.RateLimitError | Raised when the provider's rate limit is exceeded |
+| 500 | APIConnectionError | openai.APIConnectionError | If any unmapped error is returned, this error is raised |
+| 500 | APIError | openai.APIError | Generic 500-status code error |
+| 503 | ServiceUnavailableError | openai.APIStatusError | If the provider returns a service unavailable error, this error is raised |
+| >=500 | InternalServerError | openai.InternalServerError | If any unmapped 500-status code error is returned, this error is raised |
+| N/A | APIResponseValidationError | openai.APIResponseValidationError | If Rules are used and a request/response fails a rule, this error is raised |
+| N/A | BudgetExceededError | Exception | Raised by the proxy when the budget is exceeded |
+| N/A | JSONSchemaValidationError | litellm.APIResponseValidationError | Raised when the response does not match the expected JSON schema - used if the `response_schema` param is passed with `enforce_validation=True` |
+| N/A | MockException | Exception | Internal exception raised by the mock_completion class. Do not use directly |
+| N/A | OpenAIError | openai.OpenAIError | Deprecated internal exception, inherits from openai.OpenAIError. |
+
 Base case we return APIConnectionError

diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py
index af191eaa0..0e738561b 100644
--- a/litellm/_service_logger.py
+++ b/litellm/_service_logger.py
@@ -1,3 +1,4 @@
+import asyncio
 from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Any, Optional, Union

@@ -32,14 +33,63 @@ class ServiceLogging(CustomLogger):
         self.prometheusServicesLogger = PrometheusServicesLogger()

     def service_success_hook(
-        self, service: ServiceTypes, duration: float, call_type: str
+        self,
+        service: ServiceTypes,
+        duration: float,
+        call_type: str,
+        parent_otel_span: Optional[Span] = None,
+        start_time: Optional[Union[datetime, float]] = None,
+        end_time: Optional[Union[float, datetime]] = None,
     ):
         """
-        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
+        Handles both sync and async monitoring by checking for an existing event loop.
""" + # if service == ServiceTypes.REDIS: + # print(f"SYNC service: {service}, call_type: {call_type}") if self.mock_testing: self.mock_testing_sync_success_hook += 1 + try: + # Try to get the current event loop + loop = asyncio.get_event_loop() + # Check if the loop is running + if loop.is_running(): + # If we're in a running loop, create a task + loop.create_task( + self.async_service_success_hook( + service=service, + duration=duration, + call_type=call_type, + parent_otel_span=parent_otel_span, + start_time=start_time, + end_time=end_time, + ) + ) + else: + # Loop exists but not running, we can use run_until_complete + loop.run_until_complete( + self.async_service_success_hook( + service=service, + duration=duration, + call_type=call_type, + parent_otel_span=parent_otel_span, + start_time=start_time, + end_time=end_time, + ) + ) + except RuntimeError: + # No event loop exists, create a new one and run + asyncio.run( + self.async_service_success_hook( + service=service, + duration=duration, + call_type=call_type, + parent_otel_span=parent_otel_span, + start_time=start_time, + end_time=end_time, + ) + ) + def service_failure_hook( self, service: ServiceTypes, duration: float, error: Exception, call_type: str ): @@ -62,6 +112,8 @@ class ServiceLogging(CustomLogger): """ - For counting if the redis, postgres call is successful """ + # if service == ServiceTypes.REDIS: + # print(f"service: {service}, call_type: {call_type}") if self.mock_testing: self.mock_testing_async_success_hook += 1 diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 8604bdad6..0160f2f0c 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -143,7 +143,17 @@ class RedisCache(BaseCache): ) key = self.check_and_fix_namespace(key=key) try: + start_time = time.time() self.redis_client.set(name=key, value=str(value), ex=ttl) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="set_cache", + start_time=start_time, + end_time=end_time, + ) except Exception as e: # NON blocking - notify users Redis is throwing an exception print_verbose( @@ -157,14 +167,44 @@ class RedisCache(BaseCache): start_time = time.time() set_ttl = self.get_ttl(ttl=ttl) try: + start_time = time.time() result: int = _redis_client.incr(name=key, amount=value) # type: ignore + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache", + start_time=start_time, + end_time=end_time, + ) if set_ttl is not None: # check if key already has ttl, if not -> set ttl + start_time = time.time() current_ttl = _redis_client.ttl(key) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache_ttl", + start_time=start_time, + end_time=end_time, + ) if current_ttl == -1: # Key has no expiration + start_time = time.time() _redis_client.expire(key, set_ttl) # type: ignore + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="increment_cache_expire", + start_time=start_time, + end_time=end_time, + ) return result except Exception as e: ## LOGGING ## @@ -565,7 +605,17 @@ class RedisCache(BaseCache): try: key = 
self.check_and_fix_namespace(key=key) print_verbose(f"Get Redis Cache: key: {key}") + start_time = time.time() cached_response = self.redis_client.get(key) + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="get_cache", + start_time=start_time, + end_time=end_time, + ) print_verbose( f"Got Redis Cache: key: {key}, cached_response {cached_response}" ) @@ -586,7 +636,17 @@ class RedisCache(BaseCache): for cache_key in key_list: cache_key = self.check_and_fix_namespace(key=cache_key) _keys.append(cache_key) + start_time = time.time() results: List = self.redis_client.mget(keys=_keys) # type: ignore + end_time = time.time() + _duration = end_time - start_time + self.service_logger_obj.service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="batch_get_cache", + start_time=start_time, + end_time=end_time, + ) # Associate the results back with their keys. # 'results' is a list of values corresponding to the order of keys in 'key_list'. @@ -725,6 +785,8 @@ class RedisCache(BaseCache): service=ServiceTypes.REDIS, duration=_duration, call_type="sync_ping", + start_time=start_time, + end_time=end_time, ) return response except Exception as e: diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 423ccd603..fba8a7e58 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -661,13 +661,7 @@ class APIResponseValidationError(openai.APIResponseValidationError): # type: ig return _message -class OpenAIError(openai.OpenAIError): # type: ignore - def __init__(self, original_exception=None): - super().__init__() - self.llm_provider = "openai" - - -class JSONSchemaValidationError(APIError): +class JSONSchemaValidationError(APIResponseValidationError): def __init__( self, model: str, llm_provider: str, raw_response: str, schema: str ) -> None: @@ -678,9 +672,13 @@ class JSONSchemaValidationError(APIError): model, raw_response, schema ) self.message = message - super().__init__( - model=model, message=message, llm_provider=llm_provider, status_code=500 - ) + super().__init__(model=model, message=message, llm_provider=llm_provider) + + +class OpenAIError(openai.OpenAIError): # type: ignore + def __init__(self, original_exception=None): + super().__init__() + self.llm_provider = "openai" class UnsupportedParamsError(BadRequestError): diff --git a/litellm/main.py b/litellm/main.py index f6680f2df..6829de677 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -933,12 +933,7 @@ def completion( # type: ignore # noqa: PLR0915 "input_cost_per_token": input_cost_per_token, "output_cost_per_token": output_cost_per_token, "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, + } } ) elif ( @@ -951,12 +946,7 @@ def completion( # type: ignore # noqa: PLR0915 "input_cost_per_second": input_cost_per_second, "output_cost_per_second": output_cost_per_second, "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, + } } ) ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### @@ -3331,7 +3321,7 @@ def embedding( # noqa: PLR0915 if input_cost_per_token is not None and output_cost_per_token is not None: litellm.register_model( { - model: { + 
f"{custom_llm_provider}/{model}": { "input_cost_per_token": input_cost_per_token, "output_cost_per_token": output_cost_per_token, "litellm_provider": custom_llm_provider, @@ -3342,7 +3332,7 @@ def embedding( # noqa: PLR0915 output_cost_per_second = output_cost_per_second or 0.0 litellm.register_model( { - model: { + f"{custom_llm_provider}/{model}": { "input_cost_per_second": input_cost_per_second, "output_cost_per_second": output_cost_per_second, "litellm_provider": custom_llm_provider, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ee56b3366..ad045adb5 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,15 +1,19 @@ model_list: - - model_name: gpt-4o + - model_name: claude-3-5-sonnet-20240620 litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + model: claude-3-5-sonnet-20240620 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-3-5-sonnet-aihubmix + litellm_params: + model: openai/claude-3-5-sonnet-20240620 + input_cost_per_token: 0.000003 # 3$/M + output_cost_per_token: 0.000015 # 15$/M + api_base: "https://exampleopenaiendpoint-production.up.railway.app" + api_key: my-fake-key litellm_settings: - callbacks: ["prometheus", "otel"] - -general_settings: - user_api_key_cache_ttl: 3600 + fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }] + callbacks: ["otel"] router_settings: routing_strategy: latency-based-routing @@ -19,32 +23,6 @@ router_settings: # consider last five minutes of calls for latency calculation ttl: 300 - - # model_group_alias: - # gpt-4o: gpt-4o-128k-2024-05-13 - # gpt-4o-mini: gpt-4o-mini-128k-2024-07-18 - - enable_tag_filtering: True - - # retry call 3 times on each model_name (we don't use fallbacks, so this would be 3 times total) - num_retries: 3 - - # -- cooldown settings -- - # see https://github.com/BerriAI/litellm/blob/main/litellm/router_utils/cooldown_handlers.py#L265 - - # cooldown model if it fails > n calls in a minute. 
- allowed_fails: 2 - - # (in seconds) how long to cooldown model if fails/min > allowed_fails - cooldown_time: 60 - - allowed_fails_policy: - InternalServerErrorAllowedFails: 1 - RateLimitErrorAllowedFails: 2 - TimeoutErrorAllowedFails: 3 - # -- end cooldown settings -- - - # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url redis_host: os.environ/REDIS_HOST redis_port: os.environ/REDIS_PORT redis_password: os.environ/REDIS_PASSWORD diff --git a/litellm/utils.py b/litellm/utils.py index deb3ae8c6..5f86fd894 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2003,6 +2003,7 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915 }, } """ + loaded_model_cost = {} if isinstance(model_cost, dict): loaded_model_cost = model_cost diff --git a/tests/documentation_tests/test_exception_types.py b/tests/documentation_tests/test_exception_types.py new file mode 100644 index 000000000..87e128605 --- /dev/null +++ b/tests/documentation_tests/test_exception_types.py @@ -0,0 +1,81 @@ +import os +import sys +import traceback + +from dotenv import load_dotenv + +load_dotenv() +import io +import re + +# Backup the original sys.path +original_sys_path = sys.path.copy() + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + +public_exceptions = litellm.LITELLM_EXCEPTION_TYPES +# Regular expression to extract the error name +error_name_pattern = re.compile(r"\.exceptions\.([A-Za-z]+Error)") + +# Extract error names from each item +error_names = { + error_name_pattern.search(str(item)).group(1) + for item in public_exceptions + if error_name_pattern.search(str(item)) +} + + +# sys.path = original_sys_path + + +# Parse the documentation to extract documented keys +# repo_base = "./" +repo_base = "../../" +print(os.listdir(repo_base)) +docs_path = f"{repo_base}/docs/my-website/docs/exception_mapping.md" # Path to the documentation +documented_keys = set() +try: + with open(docs_path, "r", encoding="utf-8") as docs_file: + content = docs_file.read() + + exceptions_section = re.search( + r"## LiteLLM Exceptions(.*?)\n##", content, re.DOTALL + ) + if exceptions_section: + # Step 2: Extract the table content + table_content = exceptions_section.group(1) + + # Step 3: Create a pattern to capture the Error Types from each row + error_type_pattern = re.compile(r"\|\s*[^|]+\s*\|\s*([^\|]+?)\s*\|") + + # Extract the error types + exceptions = error_type_pattern.findall(table_content) + print(f"exceptions: {exceptions}") + + # Remove extra spaces if any + exceptions = [exception.strip() for exception in exceptions] + + print(exceptions) + documented_keys.update(exceptions) + +except Exception as e: + raise Exception( + f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}" + ) + +print(documented_keys) +print(public_exceptions) +print(error_names) + +# Compare and find undocumented keys +undocumented_keys = error_names - documented_keys + +if undocumented_keys: + raise Exception( + f"\nKeys not documented in 'LiteLLM Exceptions': {undocumented_keys}" + ) +else: + print("\nAll keys are documented in 'LiteLLM Exceptions'. 
- {}".format(error_names)) diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py index 7a693dd03..32bf0f92f 100644 --- a/tests/local_testing/test_router_fallbacks.py +++ b/tests/local_testing/test_router_fallbacks.py @@ -1337,3 +1337,64 @@ async def test_anthropic_streaming_fallbacks(sync_mode): mock_client.assert_called_once() print(chunks) assert len(chunks) > 0 + + +def test_router_fallbacks_with_custom_model_costs(): + """ + Tests prod use-case where a custom model is registered with a different provider + custom costs. + + Goal: make sure custom model doesn't override default model costs. + """ + model_list = [ + { + "model_name": "claude-3-5-sonnet-20240620", + "litellm_params": { + "model": "claude-3-5-sonnet-20240620", + "api_key": os.environ["ANTHROPIC_API_KEY"], + "input_cost_per_token": 30, + "output_cost_per_token": 60, + }, + }, + { + "model_name": "claude-3-5-sonnet-aihubmix", + "litellm_params": { + "model": "openai/claude-3-5-sonnet-20240620", + "input_cost_per_token": 0.000003, # 3$/M + "output_cost_per_token": 0.000015, # 15$/M + "api_base": "https://exampleopenaiendpoint-production.up.railway.app", + "api_key": "my-fake-key", + }, + }, + ] + + router = Router( + model_list=model_list, + fallbacks=[{"claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"]}], + ) + + router.completion( + model="claude-3-5-sonnet-aihubmix", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + + model_info = litellm.get_model_info(model="claude-3-5-sonnet-20240620") + + print(f"key: {model_info['key']}") + + assert model_info["litellm_provider"] == "anthropic" + + response = router.completion( + model="claude-3-5-sonnet-20240620", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + + print(f"response_cost: {response._hidden_params['response_cost']}") + + assert response._hidden_params["response_cost"] > 10 + + model_info = litellm.get_model_info(model="claude-3-5-sonnet-20240620") + + print(f"key: {model_info['key']}") + + assert model_info["input_cost_per_token"] == 30 + assert model_info["output_cost_per_token"] == 60