From 60709a075364990b79dada917a8250a24919c4e2 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 14 Sep 2024 10:02:55 -0700 Subject: [PATCH] LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689) * refactor: cleanup unused variables + fix pyright errors * feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686 * fix(o1_reasoning.py): add stricter check for o-1 reasoning model * refactor(mistral/): make it easier to see mistral transformation logic * fix(openai.py): fix openai o-1 model param mapping Fixes https://github.com/BerriAI/litellm/issues/5685 * feat(main.py): infer finetuned gemini model from base model Fixes https://github.com/BerriAI/litellm/issues/5678 * docs(vertex.md): update docs to call finetuned gemini models * feat(proxy_server.py): allow admin to hide proxy model aliases Closes https://github.com/BerriAI/litellm/issues/5692 * docs(load_balancing.md): add docs on hiding alias models from proxy config * fix(base.py): don't raise notimplemented error * fix(user_api_key_auth.py): fix model max budget check * fix(router.py): fix elif * fix(user_api_key_auth.py): don't set team_id to empty str * fix(team_endpoints.py): fix response type * test(test_completion.py): handle predibase error * test(test_proxy_server.py): fix test * fix(o1_transformation.py): fix max_completion_token mapping * test(test_image_generation.py): mark flaky test --- .pre-commit-config.yaml | 6 +- docs/my-website/docs/index.md | 2 +- docs/my-website/docs/providers/vertex.md | 55 +++- docs/my-website/docs/proxy/load_balancing.md | 59 +++- docs/my-website/src/pages/index.md | 2 +- litellm/__init__.py | 7 +- .../SlackAlerting/slack_alerting.py | 47 ++- litellm/llms/OpenAI/gpt_transformation.py | 142 ++++++++ .../{o1_reasoning.py => o1_transformation.py} | 22 +- litellm/llms/OpenAI/openai.py | 221 +++---------- litellm/llms/README.md | 12 + litellm/llms/base.py | 18 +- litellm/llms/mistral/chat.py | 5 + litellm/llms/mistral/embedding.py | 5 + .../mistral/mistral_chat_transformation.py | 126 +++++++ .../mistral_embedding_transformation.py | 0 litellm/main.py | 25 +- litellm/proxy/_new_secret_config.yaml | 10 +- litellm/proxy/_types.py | 68 ++-- litellm/proxy/auth/user_api_key_auth.py | 54 +-- .../management_endpoints/team_endpoints.py | 35 +- litellm/proxy/proxy_server.py | 307 +++++++++--------- litellm/router.py | 148 ++++++--- litellm/router_strategy/tag_based_routing.py | 2 +- .../tests/test_amazing_vertex_completion.py | 41 +++ litellm/tests/test_completion.py | 2 + litellm/tests/test_health_check.py | 18 + litellm/tests/test_image_generation.py | 1 + litellm/tests/test_key_generate_prisma.py | 4 +- litellm/tests/test_optional_params.py | 11 + litellm/tests/test_proxy_server.py | 32 +- litellm/tests/test_router.py | 28 ++ litellm/types/router.py | 5 + litellm/utils.py | 33 +- pyrightconfig.json | 6 + 35 files changed, 1020 insertions(+), 539 deletions(-) create mode 100644 litellm/llms/OpenAI/gpt_transformation.py rename litellm/llms/OpenAI/{o1_reasoning.py => o1_transformation.py} (84%) create mode 100644 litellm/llms/README.md create mode 100644 litellm/llms/mistral/chat.py create mode 100644 litellm/llms/mistral/embedding.py create mode 100644 litellm/llms/mistral/mistral_chat_transformation.py create mode 100644 litellm/llms/mistral/mistral_embedding_transformation.py create mode 100644 pyrightconfig.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a33473b72..4f93569b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,9 
@@ repos: - repo: local hooks: - - id: mypy - name: mypy - entry: python3 -m mypy --ignore-missing-imports + - id: pyright + name: pyright + entry: pyright language: system types: [python] files: ^litellm/ diff --git a/docs/my-website/docs/index.md b/docs/my-website/docs/index.md index fb1097ee1..291c51ab0 100644 --- a/docs/my-website/docs/index.md +++ b/docs/my-website/docs/index.md @@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem'; https://github.com/BerriAI/litellm -## **Call 100+ LLMs using the same Input/Output Format** +## **Call 100+ LLMs using the OpenAI Input/Output Format** - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']` diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index ac134d009..baa5e3623 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1180,9 +1180,58 @@ response = completion( Fine tuned models on vertex have a numerical model/endpoint id. -| Model Name | Function Call | -|------------------|--------------------------------------| -| your fine tuned model | `completion(model='vertex_ai/4965075652664360960', messages)`| + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/", # e.g. vertex_ai/4965075652664360960 + messages=[{ "content": "Hello, how are you?","role": "user"}], + base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing +) +``` + + + + +1. Add Vertex Credentials to your env + +```bash +!gcloud auth application-default login +``` + +2. Setup config.yaml + +```yaml +- model_name: finetuned-gemini + litellm_params: + model: vertex_ai/ + vertex_project: + vertex_location: + model_info: + base_model: vertex_ai/gemini-1.5-pro # IMPORTANT +``` + +3. Test it! + +```bash +curl --location 'https://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: ' \ +--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}' +``` + + + + + ## Gemini Pro Vision | Model Name | Function Call | diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index ff3a351c6..20b803777 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -38,9 +38,24 @@ router_settings: ## Router settings on config - routing_strategy, model_group_alias +Expose an 'alias' for a 'model_name' on the proxy server. + +``` +model_group_alias: { + "gpt-4": "gpt-3.5-turbo" +} +``` + +These aliases are shown on `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` by default. + litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`,`timeout` . 
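These `router_settings` keys correspond to `litellm.Router` constructor arguments when you use the SDK directly instead of the proxy. Below is a minimal sketch of that mapping (the Azure deployment name, `api_base`, and `api_key` values are placeholders for illustration, not taken from this config):

```python
# Minimal sketch: the `router_settings` keys above map to litellm.Router kwargs.
# Placeholder values (<...>) are assumptions for illustration only.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/<your-deployment-name>",
                "api_base": "<your-azure-api-base>",
                "api_key": "<your-azure-api-key>",
            },
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},  # requests for `gpt-4` route to the `gpt-3.5-turbo` deployments
    routing_strategy="least-busy",
    num_retries=2,
    timeout=30,  # seconds
)

# Requests made with the alias resolve to the aliased model group:
# response = router.completion(
#     model="gpt-4",
#     messages=[{"role": "user", "content": "hi"}],
# )
```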
See all Router supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64) + + +### Usage + Example config with `router_settings` + ```yaml model_list: - model_name: gpt-3.5-turbo @@ -48,19 +63,41 @@ model_list: model: azure/ api_base: api_key: - rpm: 6 # Rate limit for this deployment: in requests per minute (rpm) + +router_settings: + model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models +``` + +### Hide Alias Models + +Use this if you want to set-up aliases for: + +1. typos +2. minor model version changes +3. case sensitive changes between updates + +```yaml +model_list: - model_name: gpt-3.5-turbo litellm_params: - model: azure/gpt-turbo-small-ca - api_base: https://my-endpoint-canada-berri992.openai.azure.com/ + model: azure/ + api_base: api_key: - rpm: 6 + router_settings: - model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo` - routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] - num_retries: 2 - timeout: 30 # 30 seconds - redis_host: - redis_password: - redis_port: 1992 + model_group_alias: + "GPT-3.5-turbo": # alias + model: "gpt-3.5-turbo" # Actual model name in 'model_list' + hidden: true # Exclude from `/v1/models`, `/v1/model/info`, `/v1/model_group/info` +``` + +### Complete Spec + +```python +model_group_alias: Optional[Dict[str, Union[str, RouterModelGroupAliasItem]]] = {} + + +class RouterModelGroupAliasItem(TypedDict): + model: str + hidden: bool # if 'True', don't return on `/v1/models`, `/v1/model/info`, `/v1/model_group/info` ``` \ No newline at end of file diff --git a/docs/my-website/src/pages/index.md b/docs/my-website/src/pages/index.md index 36d47aedf..d074fd376 100644 --- a/docs/my-website/src/pages/index.md +++ b/docs/my-website/src/pages/index.md @@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem'; https://github.com/BerriAI/litellm -## **Call 100+ LLMs using the same Input/Output Format** +## **Call 100+ LLMs using the OpenAI Input/Output Format** - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']` diff --git a/litellm/__init__.py b/litellm/__init__.py index a0347d258..047927dd9 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -939,15 +939,18 @@ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConf from .llms.OpenAI.openai import ( OpenAIConfig, OpenAITextCompletionConfig, - MistralConfig, MistralEmbeddingConfig, DeepInfraConfig, GroqConfig, AzureAIStudioConfig, ) -from .llms.OpenAI.o1_reasoning import ( +from .llms.mistral.mistral_chat_transformation import MistralConfig +from .llms.OpenAI.o1_transformation import ( OpenAIO1Config, ) +from .llms.OpenAI.gpt_transformation import ( + OpenAIGPTConfig, +) from .llms.nvidia_nim import NvidiaNimConfig from .llms.cerebras.chat import CerebrasConfig from .llms.AI21.chat import AI21ChatConfig diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index a0c141969..38a9171a9 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -52,7 +52,9 @@ class SlackAlerting(CustomBatchLogger): def __init__( self, 
internal_usage_cache: Optional[DualCache] = None, - alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds) + alerting_threshold: Optional[ + float + ] = None, # threshold for slow / hanging llm responses (in seconds) alerting: Optional[List] = [], alert_types: List[AlertType] = [ "llm_exceptions", @@ -74,6 +76,8 @@ class SlackAlerting(CustomBatchLogger): default_webhook_url: Optional[str] = None, **kwargs, ): + if alerting_threshold is None: + alerting_threshold = 300 self.alerting_threshold = alerting_threshold self.alerting = alerting self.alert_types = alert_types @@ -99,6 +103,7 @@ class SlackAlerting(CustomBatchLogger): ): if alerting is not None: self.alerting = alerting + asyncio.create_task(self.periodic_flush()) if alerting_threshold is not None: self.alerting_threshold = alerting_threshold if alert_types is not None: @@ -114,8 +119,6 @@ class SlackAlerting(CustomBatchLogger): if llm_router is not None: self.llm_router = llm_router - asyncio.create_task(self.periodic_flush()) - async def deployment_in_cooldown(self): pass @@ -208,15 +211,20 @@ class SlackAlerting(CustomBatchLogger): _deployment_latencies = metadata["_latency_per_deployment"] if len(_deployment_latencies) == 0: return None + _deployment_latency_map: Optional[dict] = None try: # try sorting deployments by latency _deployment_latencies = sorted( _deployment_latencies.items(), key=lambda x: x[1] ) - _deployment_latencies = dict(_deployment_latencies) - except: + _deployment_latency_map = dict(_deployment_latencies) + except Exception: pass - for api_base, latency in _deployment_latencies.items(): + + if _deployment_latency_map is None: + return + + for api_base, latency in _deployment_latency_map.items(): _message_to_send += f"\n{api_base}: {round(latency,2)}s" _message_to_send = "```" + _message_to_send + "```" return _message_to_send @@ -475,6 +483,7 @@ class SlackAlerting(CustomBatchLogger): ): if self.alerting is None or self.alert_types is None: return + model: str = "" if request_data is not None: model = request_data.get("model", "") messages = request_data.get("messages", None) @@ -619,6 +628,7 @@ class SlackAlerting(CustomBatchLogger): return _id: Optional[str] = "default_id" # used for caching user_info_json = user_info.model_dump(exclude_none=True) + user_info_str = "" for k, v in user_info_json.items(): user_info_str = "\n{}: {}\n".format(k, v) @@ -1475,10 +1485,10 @@ Model Info: if isinstance(response_obj, litellm.ModelResponse) and ( hasattr(response_obj, "usage") - and response_obj.usage is not None - and hasattr(response_obj.usage, "completion_tokens") + and response_obj.usage is not None # type: ignore + and hasattr(response_obj.usage, "completion_tokens") # type: ignore ): - completion_tokens = response_obj.usage.completion_tokens + completion_tokens = response_obj.usage.completion_tokens # type: ignore if completion_tokens is not None and completion_tokens > 0: final_value = float( response_s.total_seconds() / completion_tokens @@ -1608,10 +1618,14 @@ Model Info: todays_date = datetime.datetime.now().date() start_date = todays_date - datetime.timedelta(days=days) - spend_per_team, spend_per_tag = await _get_spend_report_for_time_range( + _resp = await _get_spend_report_for_time_range( start_date=start_date.strftime("%Y-%m-%d"), end_date=todays_date.strftime("%Y-%m-%d"), ) + if _resp is None: + return + + spend_per_team, spend_per_tag = _resp _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} 
days)*\n" @@ -1656,13 +1670,16 @@ Model Info: days=last_day_of_month - 1 ) - monthly_spend_per_team, monthly_spend_per_tag = ( - await _get_spend_report_for_time_range( - start_date=first_day_of_month.strftime("%Y-%m-%d"), - end_date=last_day_of_month.strftime("%Y-%m-%d"), - ) + _resp = await _get_spend_report_for_time_range( + start_date=first_day_of_month.strftime("%Y-%m-%d"), + end_date=last_day_of_month.strftime("%Y-%m-%d"), ) + if _resp is None: + return + + monthly_spend_per_team, monthly_spend_per_tag = _resp + _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n" if monthly_spend_per_team is not None: diff --git a/litellm/llms/OpenAI/gpt_transformation.py b/litellm/llms/OpenAI/gpt_transformation.py new file mode 100644 index 000000000..be14031bd --- /dev/null +++ b/litellm/llms/OpenAI/gpt_transformation.py @@ -0,0 +1,142 @@ +""" +Support for gpt model family +""" + +import types +from typing import Optional, Union + +import litellm +from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage + + +class OpenAIGPTConfig: + """ + Reference: https://platform.openai.com/docs/api-reference/chat/create + + The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters: + + - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition. + + - `function_call` (string or object): This optional parameter controls how the model calls functions. + + - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs. + + - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. + + - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. + + - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message. + + - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics. + + - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. + + - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. + + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
+ """ + + frequency_penalty: Optional[int] = None + function_call: Optional[Union[str, dict]] = None + functions: Optional[list] = None + logit_bias: Optional[dict] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[int] = None + stop: Optional[Union[str, list]] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + response_format: Optional[dict] = None + + def __init__( + self, + frequency_penalty: Optional[int] = None, + function_call: Optional[Union[str, dict]] = None, + functions: Optional[list] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + response_format: Optional[dict] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self, model: str) -> list: + base_params = [ + "frequency_penalty", + "logit_bias", + "logprobs", + "top_logprobs", + "max_tokens", + "n", + "presence_penalty", + "seed", + "stop", + "stream", + "stream_options", + "temperature", + "top_p", + "tools", + "tool_choice", + "function_call", + "functions", + "max_retries", + "extra_headers", + "parallel_tool_calls", + ] # works across all models + + model_specific_params = [] + if ( + model != "gpt-3.5-turbo-16k" and model != "gpt-4" + ): # gpt-4 does not support 'response_format' + model_specific_params.append("response_format") + + if ( + model in litellm.open_ai_chat_completion_models + ) or model in litellm.open_ai_text_completion_models: + model_specific_params.append( + "user" + ) # user is not a param supported by all openai-compatible endpoints - e.g. 
azure ai + return base_params + model_specific_params + + def _map_openai_params( + self, non_default_params: dict, optional_params: dict, model: str + ) -> dict: + supported_openai_params = self.get_supported_openai_params(model) + for param, value in non_default_params.items(): + if param in supported_openai_params: + optional_params[param] = value + return optional_params + + def map_openai_params( + self, non_default_params: dict, optional_params: dict, model: str + ) -> dict: + return self._map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) diff --git a/litellm/llms/OpenAI/o1_reasoning.py b/litellm/llms/OpenAI/o1_transformation.py similarity index 84% rename from litellm/llms/OpenAI/o1_reasoning.py rename to litellm/llms/OpenAI/o1_transformation.py index bcab17660..2184b1f4e 100644 --- a/litellm/llms/OpenAI/o1_reasoning.py +++ b/litellm/llms/OpenAI/o1_transformation.py @@ -17,13 +17,12 @@ from typing import Any, List, Optional, Union import litellm from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage -from .openai import OpenAIConfig +from .gpt_transformation import OpenAIGPTConfig -class OpenAIO1Config(OpenAIConfig): +class OpenAIO1Config(OpenAIGPTConfig): """ Reference: https://platform.openai.com/docs/guides/reasoning - """ @classmethod @@ -50,9 +49,7 @@ class OpenAIO1Config(OpenAIConfig): """ - all_openai_params = litellm.OpenAIConfig().get_supported_openai_params( - model="gpt-4o" - ) + all_openai_params = super().get_supported_openai_params(model=model) non_supported_params = [ "logprobs", "tools", @@ -69,13 +66,14 @@ class OpenAIO1Config(OpenAIConfig): def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str ): - for param, value in non_default_params.items(): - if param == "max_tokens": - optional_params["max_completion_tokens"] = value - return optional_params + if "max_tokens" in non_default_params: + optional_params["max_completion_tokens"] = non_default_params.pop( + "max_tokens" + ) + return super()._map_openai_params(non_default_params, optional_params, model) def is_model_o1_reasoning_model(self, model: str) -> bool: - if "o1" in model: + if model in litellm.open_ai_chat_completion_models and "o1" in model: return True return False @@ -93,7 +91,7 @@ class OpenAIO1Config(OpenAIConfig): ) messages[i] = new_message # Replace the old message with the new one - if isinstance(message["content"], list): + if "content" in message and isinstance(message["content"], list): new_content = [] for content_item in message["content"]: if content_item.get("type") == "image_url": diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py index 89f397032..8504d5fe2 100644 --- a/litellm/llms/OpenAI/openai.py +++ b/litellm/llms/OpenAI/openai.py @@ -60,122 +60,6 @@ class OpenAIError(Exception): ) # Call the base class constructor with the parameters it needs -class MistralConfig: - """ - Reference: https://docs.mistral.ai/api/ - - The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters: - - - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7. - - - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1. - - - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null. 
- - - `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs. - - - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'. - - - `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array - - - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results. - - - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'. - - - `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message. - """ - - temperature: Optional[int] = None - top_p: Optional[int] = None - max_tokens: Optional[int] = None - tools: Optional[list] = None - tool_choice: Optional[Literal["auto", "any", "none"]] = None - random_seed: Optional[int] = None - safe_prompt: Optional[bool] = None - response_format: Optional[dict] = None - stop: Optional[Union[str, list]] = None - - def __init__( - self, - temperature: Optional[int] = None, - top_p: Optional[int] = None, - max_tokens: Optional[int] = None, - tools: Optional[list] = None, - tool_choice: Optional[Literal["auto", "any", "none"]] = None, - random_seed: Optional[int] = None, - safe_prompt: Optional[bool] = None, - response_format: Optional[dict] = None, - stop: Optional[Union[str, list]] = None, - ) -> None: - locals_ = locals().copy() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self): - return [ - "stream", - "temperature", - "top_p", - "max_tokens", - "tools", - "tool_choice", - "seed", - "stop", - "response_format", - ] - - def _map_tool_choice(self, tool_choice: str) -> str: - if tool_choice == "auto" or tool_choice == "none": - return tool_choice - elif tool_choice == "required": - return "any" - else: # openai 'tool_choice' object param not supported by Mistral API - return "any" - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens": - optional_params["max_tokens"] = value - if param == "tools": - optional_params["tools"] = value - if param == "stream" and value is True: - optional_params["stream"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["top_p"] = value - if param == "stop": - optional_params["stop"] = value - if param == "tool_choice" and isinstance(value, str): - optional_params["tool_choice"] = self._map_tool_choice( - tool_choice=value - ) - if param == "seed": - optional_params["extra_body"] = 
{"random_seed": value} - if param == "response_format": - optional_params["response_format"] = value - return optional_params - - class MistralEmbeddingConfig: """ Reference: https://docs.mistral.ai/api/#operation/createEmbedding @@ -526,44 +410,19 @@ class OpenAIConfig: } def get_supported_openai_params(self, model: str) -> list: - base_params = [ - "frequency_penalty", - "logit_bias", - "logprobs", - "top_logprobs", - "max_tokens", - "n", - "presence_penalty", - "seed", - "stop", - "stream", - "stream_options", - "temperature", - "top_p", - "tools", - "tool_choice", - "function_call", - "functions", - "max_retries", - "extra_headers", - "parallel_tool_calls", - ] # works across all models - - model_specific_params = [] if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model): return litellm.OpenAIO1Config().get_supported_openai_params(model=model) - if ( - model != "gpt-3.5-turbo-16k" and model != "gpt-4" - ): # gpt-4 does not support 'response_format' - model_specific_params.append("response_format") + else: + return litellm.OpenAIGPTConfig().get_supported_openai_params(model=model) - if ( - model in litellm.open_ai_chat_completion_models - ) or model in litellm.open_ai_text_completion_models: - model_specific_params.append( - "user" - ) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai - return base_params + model_specific_params + def _map_openai_params( + self, non_default_params: dict, optional_params: dict, model: str + ) -> dict: + supported_openai_params = self.get_supported_openai_params(model) + for param, value in non_default_params.items(): + if param in supported_openai_params: + optional_params[param] = value + return optional_params def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str @@ -575,11 +434,11 @@ class OpenAIConfig: optional_params=optional_params, model=model, ) - supported_openai_params = self.get_supported_openai_params(model) - for param, value in non_default_params.items(): - if param in supported_openai_params: - optional_params[param] = value - return optional_params + return litellm.OpenAIGPTConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) class OpenAITextCompletionConfig: @@ -816,18 +675,18 @@ class OpenAIChatCompletion(BaseLLM): except Exception as e: raise e - def completion( + def completion( # type: ignore self, model_response: ModelResponse, timeout: Union[float, httpx.Timeout], optional_params: dict, + logging_obj: Any, model: Optional[str] = None, messages: Optional[list] = None, print_verbose: Optional[Callable] = None, api_key: Optional[str] = None, api_base: Optional[str] = None, acompletion: bool = False, - logging_obj=None, litellm_params=None, logger_fn=None, headers: Optional[dict] = None, @@ -858,14 +717,14 @@ class OpenAIChatCompletion(BaseLLM): # process all OpenAI compatible provider logic here if custom_llm_provider == "mistral": # check if message content passed in as list, and not string - messages = prompt_factory( + messages = prompt_factory( # type: ignore model=model, messages=messages, custom_llm_provider=custom_llm_provider, ) if custom_llm_provider == "perplexity" and messages is not None: # check if messages.name is passed + supported, if not supported remove - messages = prompt_factory( + messages = prompt_factory( # type: ignore model=model, messages=messages, custom_llm_provider=custom_llm_provider, @@ -933,7 +792,7 @@ class OpenAIChatCompletion(BaseLLM): 
status_code=422, message="max retries must be an int" ) - openai_client = self._get_openai_client( + openai_client: OpenAI = self._get_openai_client( # type: ignore is_async=False, api_key=api_key, api_base=api_base, @@ -1068,7 +927,7 @@ class OpenAIChatCompletion(BaseLLM): 2 ): # if call fails due to alternating messages, retry with reformatted message try: - openai_aclient = self._get_openai_client( + openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, api_key=api_key, api_base=api_base, @@ -1156,7 +1015,7 @@ class OpenAIChatCompletion(BaseLLM): max_retries=None, headers=None, ): - openai_client = self._get_openai_client( + openai_client: OpenAI = self._get_openai_client( # type: ignore is_async=False, api_key=api_key, api_base=api_base, @@ -1210,7 +1069,7 @@ class OpenAIChatCompletion(BaseLLM): response = None for _ in range(2): try: - openai_aclient = self._get_openai_client( + openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, api_key=api_key, api_base=api_base, @@ -1282,7 +1141,7 @@ class OpenAIChatCompletion(BaseLLM): error_headers = getattr(e, "headers", None) raise OpenAIError( status_code=500, - message=f"{str(e)}\n\nOriginal Response: {response.text}", + message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore headers=error_headers, ) else: @@ -1294,7 +1153,7 @@ class OpenAIChatCompletion(BaseLLM): ) elif hasattr(e, "status_code"): raise OpenAIError( - status_code=e.status_code, + status_code=getattr(e, "status_code", 500), message=str(e), headers=error_headers, ) @@ -1361,7 +1220,7 @@ class OpenAIChatCompletion(BaseLLM): ): response = None try: - openai_aclient = self._get_openai_client( + openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, api_key=api_key, api_base=api_base, @@ -1410,16 +1269,16 @@ class OpenAIChatCompletion(BaseLLM): status_code=status_code, message=str(e), headers=error_headers ) - def embedding( + def embedding( # type: ignore self, model: str, input: list, timeout: float, logging_obj, model_response: litellm.utils.EmbeddingResponse, + optional_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, - optional_params=None, client=None, aembedding=None, ): @@ -1452,7 +1311,7 @@ class OpenAIChatCompletion(BaseLLM): ) return response - openai_client = self._get_openai_client( + openai_client: OpenAI = self._get_openai_client( # type: ignore is_async=False, api_key=api_key, api_base=api_base, @@ -1496,11 +1355,11 @@ class OpenAIChatCompletion(BaseLLM): data: dict, model_response: ModelResponse, timeout: float, + logging_obj: Any, api_key: Optional[str] = None, api_base: Optional[str] = None, client=None, max_retries=None, - logging_obj=None, ): response = None try: @@ -1538,15 +1397,16 @@ class OpenAIChatCompletion(BaseLLM): model: Optional[str], prompt: str, timeout: float, + optional_params: dict, + logging_obj: Any, api_key: Optional[str] = None, api_base: Optional[str] = None, model_response: Optional[litellm.utils.ImageResponse] = None, - logging_obj=None, - optional_params=None, client=None, aimg_generation=None, ): exception_mapping_worked = False + data = {} try: model = model data = {"model": model, "prompt": prompt, **optional_params} @@ -1611,7 +1471,9 @@ class OpenAIChatCompletion(BaseLLM): original_response=str(e), ) if hasattr(e, "status_code"): - raise OpenAIError(status_code=e.status_code, message=str(e)) + raise OpenAIError( + status_code=getattr(e, "status_code", 500), message=str(e) + ) else: raise 
OpenAIError(status_code=500, message=str(e)) @@ -1661,7 +1523,7 @@ class OpenAIChatCompletion(BaseLLM): input=input, **optional_params, ) - return response + return response # type: ignore async def async_audio_speech( self, @@ -1784,11 +1646,8 @@ class OpenAIChatCompletion(BaseLLM): class OpenAITextCompletion(BaseLLM): - _client_session: httpx.Client - def __init__(self) -> None: super().__init__() - self._client_session = self.create_client_session() def validate_environment(self, api_key): headers = { @@ -1806,10 +1665,10 @@ class OpenAITextCompletion(BaseLLM): messages: list, timeout: float, logging_obj: LiteLLMLoggingObj, + optional_params: dict, print_verbose: Optional[Callable] = None, api_base: Optional[str] = None, acompletion: bool = False, - optional_params=None, litellm_params=None, logger_fn=None, client=None, @@ -1921,7 +1780,7 @@ class OpenAITextCompletion(BaseLLM): api_key: str, model: str, timeout: float, - max_retries=None, + max_retries: int, organization: Optional[str] = None, client=None, ): @@ -2017,9 +1876,9 @@ class OpenAITextCompletion(BaseLLM): model_response: ModelResponse, model: str, timeout: float, + max_retries: int, api_base: Optional[str] = None, client=None, - max_retries=None, organization=None, ): if client is None: diff --git a/litellm/llms/README.md b/litellm/llms/README.md new file mode 100644 index 000000000..7a8136792 --- /dev/null +++ b/litellm/llms/README.md @@ -0,0 +1,12 @@ +## File Structure + +### August 27th, 2024 + +To make it easy to see how calls are transformed for each model/provider: + +we are working on moving all supported litellm providers to a folder structure, where folder name is the supported litellm provider name. + +Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified. + +E.g. `cohere/`, `bedrock/`. + \ No newline at end of file diff --git a/litellm/llms/base.py b/litellm/llms/base.py index 08c5e1992..943b10182 100644 --- a/litellm/llms/base.py +++ b/litellm/llms/base.py @@ -66,22 +66,24 @@ class BaseLLM: return _aclient_session def __exit__(self): - if hasattr(self, "_client_session"): + if hasattr(self, "_client_session") and self._client_session is not None: self._client_session.close() async def __aexit__(self, exc_type, exc_val, exc_tb): if hasattr(self, "_aclient_session"): - await self._aclient_session.aclose() + await self._aclient_session.aclose() # type: ignore - def validate_environment(self): # set up the environment required to run the model - pass + def validate_environment( + self, *args, **kwargs + ) -> Optional[Any]: # set up the environment required to run the model + return None def completion( self, *args, **kwargs - ): # logic for parsing in - calling - parsing out model completion calls - pass + ) -> Any: # logic for parsing in - calling - parsing out model completion calls + return None def embedding( self, *args, **kwargs - ): # logic for parsing in - calling - parsing out model embedding calls - pass + ) -> Any: # logic for parsing in - calling - parsing out model embedding calls + return None diff --git a/litellm/llms/mistral/chat.py b/litellm/llms/mistral/chat.py new file mode 100644 index 000000000..fc454038f --- /dev/null +++ b/litellm/llms/mistral/chat.py @@ -0,0 +1,5 @@ +""" +Calls handled in openai/ + +as mistral is an openai-compatible endpoint. 
+""" diff --git a/litellm/llms/mistral/embedding.py b/litellm/llms/mistral/embedding.py new file mode 100644 index 000000000..fc454038f --- /dev/null +++ b/litellm/llms/mistral/embedding.py @@ -0,0 +1,5 @@ +""" +Calls handled in openai/ + +as mistral is an openai-compatible endpoint. +""" diff --git a/litellm/llms/mistral/mistral_chat_transformation.py b/litellm/llms/mistral/mistral_chat_transformation.py new file mode 100644 index 000000000..c74f26d6d --- /dev/null +++ b/litellm/llms/mistral/mistral_chat_transformation.py @@ -0,0 +1,126 @@ +""" +Transformation logic from OpenAI /v1/chat/completion format to Mistral's /chat/completion format. + +Why a separate file? To make it easy to see how the transformation works. + +Docs - https://docs.mistral.ai/api/ +""" + +import types +from typing import List, Literal, Optional, Union + + +class MistralConfig: + """ + Reference: https://docs.mistral.ai/api/ + + The class `MistralConfig` provides configuration for Mistral's Chat API interface. Below are the parameters: + + - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7. + + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1. + + - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null. + + - `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs. + + - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'. + + - `stop` (string or array of strings): Stop generation if this token is detected, or if one of these tokens is detected when providing an array. + + - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results. + + - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'. + + - `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
+ """ + + temperature: Optional[int] = None + top_p: Optional[int] = None + max_tokens: Optional[int] = None + tools: Optional[list] = None + tool_choice: Optional[Literal["auto", "any", "none"]] = None + random_seed: Optional[int] = None + safe_prompt: Optional[bool] = None + response_format: Optional[dict] = None + stop: Optional[Union[str, list]] = None + + def __init__( + self, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + max_tokens: Optional[int] = None, + tools: Optional[list] = None, + tool_choice: Optional[Literal["auto", "any", "none"]] = None, + random_seed: Optional[int] = None, + safe_prompt: Optional[bool] = None, + response_format: Optional[dict] = None, + stop: Optional[Union[str, list]] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "stream", + "temperature", + "top_p", + "max_tokens", + "tools", + "tool_choice", + "seed", + "stop", + "response_format", + ] + + def _map_tool_choice(self, tool_choice: str) -> str: + if tool_choice == "auto" or tool_choice == "none": + return tool_choice + elif tool_choice == "required": + return "any" + else: # openai 'tool_choice' object param not supported by Mistral API + return "any" + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "tools": + optional_params["tools"] = value + if param == "stream" and value is True: + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "stop": + optional_params["stop"] = value + if param == "tool_choice" and isinstance(value, str): + optional_params["tool_choice"] = self._map_tool_choice( + tool_choice=value + ) + if param == "seed": + optional_params["extra_body"] = {"random_seed": value} + if param == "response_format": + optional_params["response_format"] = value + return optional_params diff --git a/litellm/llms/mistral/mistral_embedding_transformation.py b/litellm/llms/mistral/mistral_embedding_transformation.py new file mode 100644 index 000000000..e69de29bb diff --git a/litellm/main.py b/litellm/main.py index 8df5d604d..a50c908c6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -737,6 +737,7 @@ def completion( preset_cache_key = kwargs.get("preset_cache_key", None) hf_model_name = kwargs.get("hf_model_name", None) supports_system_message = kwargs.get("supports_system_message", None) + base_model = kwargs.get("base_model", None) ### TEXT COMPLETION CALLS ### text_completion = kwargs.get("text_completion", False) atext_completion = kwargs.get("atext_completion", False) @@ -782,11 +783,9 @@ def completion( "top_logprobs", "extra_headers", ] - litellm_params = ( - all_litellm_params # use the external var., used in creating cache key as well. 
- ) - default_params = openai_params + litellm_params + default_params = openai_params + all_litellm_params + litellm_params = {} # used to prevent unbound var errors non_default_params = { k: v for k, v in kwargs.items() if k not in default_params } # model-specific params - pass them straight to the model/provider @@ -973,6 +972,7 @@ def completion( text_completion=kwargs.get("text_completion"), azure_ad_token_provider=kwargs.get("azure_ad_token_provider"), user_continue_message=kwargs.get("user_continue_message"), + base_model=base_model, ) logging.update_environment_variables( model=model, @@ -2123,7 +2123,10 @@ def completion( timeout=timeout, client=client, ) - elif "gemini" in model: + elif "gemini" in model or ( + litellm_params.get("base_model") is not None + and "gemini" in litellm_params["base_model"] + ): model_response = vertex_chat_completion.completion( # type: ignore model=model, messages=messages, @@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs): ) num_retries = kwargs.pop("num_retries", 3) - retry_strategy = kwargs.pop("retry_strategy", "constant_retry") + retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore original_function = kwargs.pop("original_function", completion) if retry_strategy == "constant_retry": retryer = tenacity.Retrying( @@ -4997,7 +5000,9 @@ def speech( async def ahealth_check( model_params: dict, mode: Optional[ - Literal["completion", "embedding", "image_generation", "chat", "batch"] + Literal[ + "completion", "embedding", "image_generation", "chat", "batch", "rerank" + ] ] = None, prompt: Optional[str] = None, input: Optional[List] = None, @@ -5113,6 +5118,12 @@ async def ahealth_check( model_params["prompt"] = prompt await litellm.aimage_generation(**model_params) response = {} + elif mode == "rerank": + model_params.pop("messages", None) + model_params["query"] = prompt + model_params["documents"] = ["my sample text"] + await litellm.arerank(**model_params) + response = {} elif "*" in model: from litellm.litellm_core_utils.llm_request_utils import ( pick_cheapest_model_from_llm_provider, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 9a3ce9692..c940e744f 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,9 +1,7 @@ model_list: - - model_name: "gpt-4o" + - model_name: gpt-3.5-turbo litellm_params: - model: gpt-4o + model: gpt-3.5-turbo -litellm_settings: - cache: true - cache_params: - type: local \ No newline at end of file +router_settings: + model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}} \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 662d6d835..db3304745 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -99,15 +99,15 @@ class LitellmUserRoles(str, enum.Enum): return ui_labels.get(self.value, "") -class LitellmTableNames(str, enum.Enum): +class LitellmTableNames(enum.Enum): """ Enum for Table Names used by LiteLLM """ - TEAM_TABLE_NAME: str = "LiteLLM_TeamTable" - USER_TABLE_NAME: str = "LiteLLM_UserTable" - KEY_TABLE_NAME: str = "LiteLLM_VerificationToken" - PROXY_MODEL_TABLE_NAME: str = "LiteLLM_ModelTable" + TEAM_TABLE_NAME = "LiteLLM_TeamTable" + USER_TABLE_NAME = "LiteLLM_UserTable" + KEY_TABLE_NAME = "LiteLLM_VerificationToken" + PROXY_MODEL_TABLE_NAME = "LiteLLM_ModelTable" AlertType = Literal[ @@ -140,7 +140,7 @@ class LiteLLMBase(BaseModel): 
Implements default functions, all pydantic objects should have. """ - def json(self, **kwargs): + def json(self, **kwargs): # type: ignore try: return self.model_dump(**kwargs) # noqa except Exception as e: @@ -170,7 +170,7 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase): class LiteLLMRoutes(enum.Enum): - openai_route_names: List = [ + openai_route_names = [ "chat_completion", "completion", "embeddings", @@ -179,7 +179,7 @@ class LiteLLMRoutes(enum.Enum): "moderations", "model_list", # OpenAI /v1/models route ] - openai_routes: List = [ + openai_routes = [ # chat completions "/engines/{model}/chat/completions", "/openai/deployments/{model}/chat/completions", @@ -247,18 +247,18 @@ class LiteLLMRoutes(enum.Enum): "/v1/rerank", ] - mapped_pass_through_routes: List = [ + mapped_pass_through_routes = [ "/bedrock", "/vertex-ai", "/gemini", "/langfuse", ] - anthropic_routes: List = [ + anthropic_routes = [ "/v1/messages", ] - info_routes: List = [ + info_routes = [ "/key/info", "/team/info", "/team/list", @@ -271,9 +271,9 @@ class LiteLLMRoutes(enum.Enum): ] # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend - master_key_only_routes: List = ["/global/spend/reset", "/key/list"] + master_key_only_routes = ["/global/spend/reset", "/key/list"] - sso_only_routes: List = [ + sso_only_routes = [ "/key/generate", "/key/update", "/key/delete", @@ -282,7 +282,7 @@ class LiteLLMRoutes(enum.Enum): "/sso/get/logout_url", ] - management_routes: List = [ # key + management_routes = [ # key "/key/generate", "/key/update", "/key/delete", @@ -307,7 +307,7 @@ class LiteLLMRoutes(enum.Enum): "/model/info", ] - spend_tracking_routes: List = [ + spend_tracking_routes = [ # spend "/spend/keys", "/spend/users", @@ -316,7 +316,7 @@ class LiteLLMRoutes(enum.Enum): "/spend/logs", ] - global_spend_tracking_routes: List = [ + global_spend_tracking_routes = [ # global spend "/global/spend/logs", "/global/spend", @@ -328,7 +328,7 @@ class LiteLLMRoutes(enum.Enum): "/global/spend/report", ] - public_routes: List = [ + public_routes = [ "/routes", "/", "/health/liveliness", @@ -339,7 +339,7 @@ class LiteLLMRoutes(enum.Enum): "/metrics", ] - internal_user_routes: List = ( + internal_user_routes = ( [ "/key/generate", "/key/update", @@ -357,7 +357,7 @@ class LiteLLMRoutes(enum.Enum): + sso_only_routes ) - self_managed_routes: List = [ + self_managed_routes = [ "/team/member_add", "/team/member_delete", ] # routes that manage their own allowed/disallowed logic @@ -581,7 +581,9 @@ class ModelParams(LiteLLMBase): @classmethod def set_model_info(cls, values): if values.get("model_info") is None: - values.update({"model_info": ModelInfo()}) + values.update( + {"model_info": ModelInfo(id=None, mode="chat", base_model=None)} + ) return values @@ -627,7 +629,7 @@ class GenerateKeyRequest(_GenerateKeyRequest): class GenerateKeyResponse(_GenerateKeyRequest): - key: str + key: str # type: ignore key_name: Optional[str] = None expires: Optional[datetime] user_id: Optional[str] = None @@ -659,7 +661,7 @@ class GenerateKeyResponse(_GenerateKeyRequest): class UpdateKeyRequest(GenerateKeyRequest): # Note: the defaults of all Params here MUST BE NONE # else they will get overwritten - key: str + key: str # type: ignore duration: Optional[str] = None spend: Optional[float] = None metadata: Optional[dict] = None @@ -976,6 +978,7 @@ class TeamCallbackMetadata(LiteLLMBase): class LiteLLM_TeamTable(TeamBase): + team_id: str # type: ignore spend: Optional[float] = None max_parallel_requests: Optional[int] = 
None budget_duration: Optional[str] = None @@ -1061,7 +1064,7 @@ class LiteLLM_OrganizationTable(LiteLLMBase): class NewOrganizationResponse(LiteLLM_OrganizationTable): - organization_id: str + organization_id: str # type: ignore created_at: datetime updated_at: datetime @@ -1388,16 +1391,7 @@ class UserAPIKeyAuth( """ api_key: Optional[str] = None - user_role: Optional[ - Literal[ - LitellmUserRoles.PROXY_ADMIN, - LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - LitellmUserRoles.TEAM, - LitellmUserRoles.CUSTOMER, - ] - ] = None + user_role: Optional[LitellmUserRoles] = None allowed_model_region: Optional[Literal["eu"]] = None parent_otel_span: Optional[Span] = None rpm_limit_per_model: Optional[Dict[str, int]] = None @@ -1716,9 +1710,9 @@ class SpendLogsPayload(TypedDict): total_tokens: int prompt_tokens: int completion_tokens: int - startTime: datetime - endTime: datetime - completionStartTime: Optional[datetime] + startTime: Union[datetime, str] + endTime: Union[datetime, str] + completionStartTime: Optional[Union[datetime, str]] model: str model_id: Optional[str] model_group: Optional[str] @@ -1891,6 +1885,6 @@ class TeamAddMemberResponse(LiteLLM_TeamTable): class TeamInfoResponseObject(TypedDict): team_id: str - team_info: TeamBase + team_info: LiteLLM_TeamTable keys: List team_memberships: List[LiteLLM_TeamMembership] diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index ebe6853ac..b2e63c256 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -109,11 +109,8 @@ async def user_api_key_auth( ), ) -> UserAPIKeyAuth: from litellm.proxy.proxy_server import ( - allowed_routes_check, - common_checks, custom_db_client, general_settings, - get_actual_routes, jwt_handler, litellm_proxy_admin_name, llm_model_list, @@ -125,6 +122,8 @@ async def user_api_key_auth( user_custom_auth, ) + parent_otel_span: Optional[Span] = None + try: route: str = get_request_route(request=request) # get the request body @@ -137,6 +136,7 @@ async def user_api_key_auth( pass_through_endpoints: Optional[List[dict]] = general_settings.get( "pass_through_endpoints", None ) + passed_in_key: Optional[str] = None if isinstance(api_key, str): passed_in_key = api_key api_key = _get_bearer_token(api_key=api_key) @@ -161,7 +161,6 @@ async def user_api_key_auth( custom_litellm_key_header_name=custom_litellm_key_header_name, ) - parent_otel_span: Optional[Span] = None if open_telemetry_logger is not None: parent_otel_span = open_telemetry_logger.tracer.start_span( name="Received Proxy Server Request", @@ -189,7 +188,7 @@ async def user_api_key_auth( ######## Route Checks Before Reading DB / Cache for "token" ################ if ( - route in LiteLLMRoutes.public_routes.value + route in LiteLLMRoutes.public_routes.value # type: ignore or route_in_additonal_public_routes(current_route=route) ): # check if public endpoint @@ -410,7 +409,7 @@ async def user_api_key_auth( #### ELSE #### ## CHECK PASS-THROUGH ENDPOINTS ## is_mapped_pass_through_route: bool = False - for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: + for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore if route.startswith(mapped_route): is_mapped_pass_through_route = True if is_mapped_pass_through_route: @@ -444,9 +443,9 @@ async def user_api_key_auth( header_key = headers.get("litellm_user_api_key", "") if ( isinstance(request.headers, dict) - and 
request.headers.get(key=header_key) is not None + and request.headers.get(key=header_key) is not None # type: ignore ): - api_key = request.headers.get(key=header_key) + api_key = request.headers.get(key=header_key) # type: ignore if master_key is None: if isinstance(api_key, str): @@ -606,7 +605,7 @@ async def user_api_key_auth( ## IF it's not a master key ## Route should not be in master_key_only_routes - if route in LiteLLMRoutes.master_key_only_routes.value: + if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore raise Exception( f"Tried to access route={route}, which is only for MASTER KEY" ) @@ -669,8 +668,9 @@ async def user_api_key_auth( "allowed_model_region" ) + user_obj: Optional[LiteLLM_UserTable] = None + valid_token_dict: dict = {} if valid_token is not None: - user_obj: Optional[LiteLLM_UserTable] = None # Got Valid Token from Cache, DB # Run checks for # 1. If token can call model @@ -686,6 +686,7 @@ async def user_api_key_auth( # Check 1. If token can call model _model_alias_map = {} + model: Optional[str] = None if ( hasattr(valid_token, "team_model_aliases") and valid_token.team_model_aliases is not None @@ -698,6 +699,7 @@ async def user_api_key_auth( _model_alias_map = {**valid_token.aliases} litellm.model_alias_map = _model_alias_map config = valid_token.config + if config != {}: model_list = config.get("model_list", []) llm_model_list = model_list @@ -887,7 +889,10 @@ async def user_api_key_auth( and max_budget_per_model.get(current_model, None) is not None ): if ( - model_spend[0]["model"] == current_model + "model" in model_spend[0] + and model_spend[0].get("model") == current_model + and "_sum" in model_spend[0] + and "spend" in model_spend[0]["_sum"] and model_spend[0]["_sum"]["spend"] >= max_budget_per_model[current_model] ): @@ -927,16 +932,19 @@ async def user_api_key_auth( ) # Check 8: Additional Common Checks across jwt + key auth - _team_obj = LiteLLM_TeamTable( - team_id=valid_token.team_id, - max_budget=valid_token.team_max_budget, - spend=valid_token.team_spend, - tpm_limit=valid_token.team_tpm_limit, - rpm_limit=valid_token.team_rpm_limit, - blocked=valid_token.team_blocked, - models=valid_token.team_models, - metadata=valid_token.team_metadata, - ) + if valid_token.team_id is not None: + _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable( + team_id=valid_token.team_id, + max_budget=valid_token.team_max_budget, + spend=valid_token.team_spend, + tpm_limit=valid_token.team_tpm_limit, + rpm_limit=valid_token.team_rpm_limit, + blocked=valid_token.team_blocked, + models=valid_token.team_models, + metadata=valid_token.team_metadata, + ) + else: + _team_obj = None user_api_key_cache.set_cache( key=valid_token.team_id, value=_team_obj @@ -1045,7 +1053,7 @@ async def user_api_key_auth( "/global/predict/spend/logs", "/global/activity", "/health/services", - ] + LiteLLMRoutes.info_routes.value + ] + LiteLLMRoutes.info_routes.value # type: ignore # check if the current route startswith any of the allowed routes if ( route is not None @@ -1106,7 +1114,7 @@ async def user_api_key_auth( # Log this exception to OTEL if open_telemetry_logger is not None: - await open_telemetry_logger.async_post_call_failure_hook( + await open_telemetry_logger.async_post_call_failure_hook( # type: ignore original_exception=e, user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span), ) diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index ff13182f7..98263cfed 100644 --- 
a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -4,10 +4,11 @@ import json import traceback import uuid from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import List, Optional, Union import fastapi from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger @@ -28,6 +29,7 @@ from litellm.proxy._types import ( ProxyErrorTypes, ProxyException, TeamAddMemberResponse, + TeamBase, TeamInfoResponseObject, TeamMemberAddRequest, TeamMemberDeleteRequest, @@ -36,6 +38,7 @@ from litellm.proxy._types import ( UpdateTeamRequest, UserAPIKeyAuth, ) +from litellm.proxy.auth.auth_checks import get_team_object from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth from litellm.proxy.management_helpers.utils import ( add_new_member, @@ -240,7 +243,7 @@ async def new_team( reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) complete_team_data.budget_reset_at = reset_at - team_row = await prisma_client.insert_data( + team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore data=complete_team_data.json(exclude_none=True), table_name="team" ) @@ -462,9 +465,6 @@ async def team_member_add( ``` """ from litellm.proxy.proxy_server import ( - _duration_in_seconds, - create_audit_log_for_update, - get_team_object, litellm_proxy_admin_name, prisma_client, proxy_logging_obj, @@ -932,10 +932,13 @@ async def delete_team( if litellm.store_audit_logs is True: # make an audit log for each team deleted for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore + team_row: Optional[LiteLLM_TeamTable] = await prisma_client.get_data( # type: ignore team_id=team_id, table_name="team", query_type="find_unique" ) + if team_row is None: + continue + _team_row = team_row.json(exclude_none=True) asyncio.create_task( @@ -1027,8 +1030,10 @@ async def team_info( ), ) - team_info = await prisma_client.get_data( - team_id=team_id, table_name="team", query_type="find_unique" + team_info: Optional[Union[LiteLLM_TeamTable, dict]] = ( + await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) ) if team_info is None: raise HTTPException( @@ -1044,6 +1049,9 @@ async def team_info( expires=datetime.now(), ) + if keys is None: + keys = [] + if team_info is None: ## make sure we still return a total spend ## spend = 0 @@ -1055,7 +1063,7 @@ async def team_info( for key in keys: try: key = key.model_dump() # noqa - except: + except Exception: # if using pydantic v1 key = key.dict() key.pop("token", None) @@ -1070,9 +1078,16 @@ async def team_info( for tm in team_memberships: returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump())) + if isinstance(team_info, dict): + _team_info = LiteLLM_TeamTable(**team_info) + elif isinstance(team_info, BaseModel): + _team_info = LiteLLM_TeamTable(**team_info.model_dump()) + else: + _team_info = LiteLLM_TeamTable() + response_object = TeamInfoResponseObject( team_id=team_id, - team_info=team_info, + team_info=_team_info, keys=keys, team_memberships=returned_tm, ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 51df7de87..1ae14de69 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -125,16 +125,7 @@ from litellm.proxy._types import * from 
litellm.proxy.analytics_endpoints.analytics_endpoints import ( router as analytics_router, ) -from litellm.proxy.auth.auth_checks import ( - allowed_routes_check, - common_checks, - get_actual_routes, - get_end_user_object, - get_org_object, - get_team_object, - get_user_object, - log_to_opentelemetry, -) +from litellm.proxy.auth.auth_checks import log_to_opentelemetry from litellm.proxy.auth.auth_utils import check_response_size_is_safe from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.litellm_license import LicenseCheck @@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import ( load_aws_secret_manager, ) from litellm.secret_managers.google_kms import load_google_kms +from litellm.secret_managers.main import get_secret from litellm.types.llms.anthropic import ( AnthropicMessagesRequest, AnthropicResponse, @@ -484,7 +476,7 @@ general_settings: dict = {} callback_settings: dict = {} log_file = "api_log.json" worker_config = None -master_key = None +master_key: Optional[str] = None otel_logging = False prisma_client: Optional[PrismaClient] = None custom_db_client: Optional[DBClient] = None @@ -874,7 +866,9 @@ def error_tracking(): def _set_spend_logs_payload( - payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None + payload: Union[dict, SpendLogsPayload], + prisma_client: PrismaClient, + spend_logs_url: Optional[str] = None, ): if prisma_client is not None and spend_logs_url is not None: if isinstance(payload["startTime"], datetime): @@ -1341,6 +1335,9 @@ async def _run_background_health_check(): # make 1 deep copy of llm_model_list -> use this for all background health checks _llm_model_list = copy.deepcopy(llm_model_list) + if _llm_model_list is None: + return + while True: healthy_endpoints, unhealthy_endpoints = await perform_health_check( model_list=_llm_model_list, details=health_check_details @@ -1352,7 +1349,10 @@ async def _run_background_health_check(): health_check_results["healthy_count"] = len(healthy_endpoints) health_check_results["unhealthy_count"] = len(unhealthy_endpoints) - await asyncio.sleep(health_check_interval) + if health_check_interval is not None and isinstance( + health_check_interval, float + ): + await asyncio.sleep(health_check_interval) class ProxyConfig: @@ -1467,7 +1467,7 @@ class ProxyConfig: break for k, v in team_config.items(): if isinstance(v, str) and v.startswith("os.environ/"): - team_config[k] = litellm.get_secret(v) + team_config[k] = get_secret(v) return team_config def _init_cache( @@ -1513,6 +1513,9 @@ class ProxyConfig: config = get_file_contents_from_s3( bucket_name=bucket_name, object_key=object_key ) + + if config is None: + raise Exception("Unable to load config from given source.") else: # default to file config = await self.get_config(config_file_path=config_file_path) @@ -1528,9 +1531,7 @@ class ProxyConfig: environment_variables = config.get("environment_variables", None) if environment_variables: for key, value in environment_variables.items(): - os.environ[key] = str( - litellm.get_secret(secret_name=key, default_value=value) - ) + os.environ[key] = str(get_secret(secret_name=key, default_value=value)) # check if litellm_license in general_settings if "LITELLM_LICENSE" in environment_variables: @@ -1566,8 +1567,8 @@ class ProxyConfig: if ( cache_type == "redis" or cache_type == "redis-semantic" ) and len(cache_params.keys()) == 0: - cache_host = litellm.get_secret("REDIS_HOST", None) - cache_port = litellm.get_secret("REDIS_PORT", None) + cache_host = 
get_secret("REDIS_HOST", None) + cache_port = get_secret("REDIS_PORT", None) cache_password = None cache_params.update( { @@ -1577,8 +1578,8 @@ class ProxyConfig: } ) - if litellm.get_secret("REDIS_PASSWORD", None) is not None: - cache_password = litellm.get_secret("REDIS_PASSWORD", None) + if get_secret("REDIS_PASSWORD", None) is not None: + cache_password = get_secret("REDIS_PASSWORD", None) cache_params.update( { "password": cache_password, @@ -1617,7 +1618,7 @@ class ProxyConfig: # users can pass os.environ/ variables on the proxy - we should read them from the env for key, value in cache_params.items(): if type(value) is str and value.startswith("os.environ/"): - cache_params[key] = litellm.get_secret(value) + cache_params[key] = get_secret(value) ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables self._init_cache(cache_params=cache_params) @@ -1738,7 +1739,7 @@ class ProxyConfig: if value is not None and isinstance(value, dict): for _k, _v in value.items(): if isinstance(_v, str) and _v.startswith("os.environ/"): - value[_k] = litellm.get_secret(_v) + value[_k] = get_secret(_v) litellm.upperbound_key_generate_params = ( LiteLLM_UpperboundKeyGenerateParams(**value) ) @@ -1812,15 +1813,15 @@ class ProxyConfig: database_url = general_settings.get("database_url", None) if database_url and database_url.startswith("os.environ/"): verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!") - database_url = litellm.get_secret(database_url) + database_url = get_secret(database_url) verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url) ### MASTER KEY ### master_key = general_settings.get( - "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) + "master_key", get_secret("LITELLM_MASTER_KEY", None) ) if master_key and master_key.startswith("os.environ/"): - master_key = litellm.get_secret(master_key) + master_key = get_secret(master_key) # type: ignore if not isinstance(master_key, str): raise Exception( "Master key must be a string. 
Current type - {}".format( @@ -1861,33 +1862,6 @@ class ProxyConfig: await initialize_pass_through_endpoints( pass_through_endpoints=general_settings["pass_through_endpoints"] ) - ## dynamodb - database_type = general_settings.get("database_type", None) - if database_type is not None and ( - database_type == "dynamo_db" or database_type == "dynamodb" - ): - database_args = general_settings.get("database_args", None) - ### LOAD FROM os.environ/ ### - for k, v in database_args.items(): - if isinstance(v, str) and v.startswith("os.environ/"): - database_args[k] = litellm.get_secret(v) - if isinstance(k, str) and k == "aws_web_identity_token": - value = database_args[k] - verbose_proxy_logger.debug( - f"Loading AWS Web Identity Token from file: {value}" - ) - if os.path.exists(value): - with open(value, "r") as file: - token_content = file.read() - database_args[k] = token_content - else: - verbose_proxy_logger.info( - f"DynamoDB Loading - {value} is not a valid file path" - ) - verbose_proxy_logger.debug("database_args: %s", database_args) - custom_db_client = DBClient( - custom_db_args=database_args, custom_db_type=database_type - ) ## ADMIN UI ACCESS ## ui_access_mode = general_settings.get( "ui_access_mode", "all" @@ -1951,7 +1925,7 @@ class ProxyConfig: ### LOAD FROM os.environ/ ### for k, v in model["litellm_params"].items(): if isinstance(v, str) and v.startswith("os.environ/"): - model["litellm_params"][k] = litellm.get_secret(v) + model["litellm_params"][k] = get_secret(v) print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa litellm_model_name = model["litellm_params"]["model"] litellm_model_api_base = model["litellm_params"].get("api_base", None) @@ -2005,7 +1979,10 @@ class ProxyConfig: ) # type:ignore # Guardrail settings - guardrails_v2 = config.get("guardrails", None) + guardrails_v2: Optional[dict] = None + + if config is not None: + guardrails_v2 = config.get("guardrails", None) if guardrails_v2: init_guardrails_v2( all_guardrails=guardrails_v2, config_file_path=config_file_path @@ -2074,7 +2051,7 @@ class ProxyConfig: ### LOAD FROM os.environ/ ### for k, v in model["litellm_params"].items(): if isinstance(v, str) and v.startswith("os.environ/"): - model["litellm_params"][k] = litellm.get_secret(v) + model["litellm_params"][k] = get_secret(v) ## check if they have model-id's ## model_id = model.get("model_info", {}).get("id", None) @@ -2234,7 +2211,8 @@ class ProxyConfig: for k, v in environment_variables.items(): try: decrypted_value = decrypt_value_helper(value=v) - os.environ[k] = decrypted_value + if decrypted_value is not None: + os.environ[k] = decrypted_value except Exception as e: verbose_proxy_logger.error( "Error setting env variable: %s - %s", k, str(e) @@ -2536,7 +2514,7 @@ async def async_assistants_data_generator( ) # chunk = chunk.model_dump_json(exclude_none=True) - async for c in chunk: + async for c in chunk: # type: ignore c = c.model_dump_json(exclude_none=True) try: yield f"data: {c}\n\n" @@ -2745,17 +2723,22 @@ async def startup_event(): ### LOAD MASTER KEY ### # check if master key set in environment - load from there - master_key = litellm.get_secret("LITELLM_MASTER_KEY", None) + master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore # check if DATABASE_URL in environment - load from there if prisma_client is None: - prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None)) + _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore + prisma_setup(database_url=_db_url) ### LOAD CONFIG ### - worker_config = 
litellm.get_secret("WORKER_CONFIG") + worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore verbose_proxy_logger.debug("worker_config: %s", worker_config) # check if it's a valid file path - if os.path.isfile(worker_config): - if proxy_config.is_yaml(config_file_path=worker_config): + if worker_config is not None: + if ( + isinstance(worker_config, str) + and os.path.isfile(worker_config) + and proxy_config.is_yaml(config_file_path=worker_config) + ): ( llm_router, llm_model_list, @@ -2763,21 +2746,23 @@ async def startup_event(): ) = await proxy_config.load_config( router=llm_router, config_file_path=worker_config ) - else: + elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance( + worker_config, str + ): + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + elif isinstance(worker_config, dict): await initialize(**worker_config) - elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: - ( - llm_router, - llm_model_list, - general_settings, - ) = await proxy_config.load_config( - router=llm_router, config_file_path=worker_config - ) - - else: - # if not, assume it's a json string - worker_config = json.loads(os.getenv("WORKER_CONFIG")) - await initialize(**worker_config) + else: + # if not, assume it's a json string + worker_config = json.loads(worker_config) + if isinstance(worker_config, dict): + await initialize(**worker_config) ## CHECK PREMIUM USER verbose_proxy_logger.debug( @@ -2825,7 +2810,7 @@ async def startup_event(): if general_settings.get("litellm_jwtauth", None) is not None: for k, v in general_settings["litellm_jwtauth"].items(): if isinstance(v, str) and v.startswith("os.environ/"): - general_settings["litellm_jwtauth"][k] = litellm.get_secret(v) + general_settings["litellm_jwtauth"][k] = get_secret(v) litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) else: litellm_jwtauth = LiteLLM_JWTAuth() @@ -2948,8 +2933,7 @@ async def startup_event(): ### ADD NEW MODELS ### store_model_in_db = ( - litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db) - or store_model_in_db + get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db ) # type: ignore if store_model_in_db == True: scheduler.add_job( @@ -3498,7 +3482,7 @@ async def completion( ) ### CALL HOOKS ### - modify outgoing data response = await proxy_logging_obj.post_call_success_hook( - data=data, user_api_key_dict=user_api_key_dict, response=response + data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore ) fastapi_response.headers.update( @@ -4000,7 +3984,7 @@ async def audio_speech( request_data=data, ) return StreamingResponse( - generate(response), media_type="audio/mpeg", headers=custom_headers + generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore ) except Exception as e: @@ -4288,6 +4272,7 @@ async def create_assistant( API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant """ global proxy_logging_obj + data = {} # ensure data always dict try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() @@ -7642,6 +7627,7 @@ async def model_group_info( ) model_groups: List[ModelGroupInfo] = [] + for model in all_models_str: _model_group_info = llm_router.get_model_group_info(model_group=model) @@ -8051,7 +8037,8 @@ async def google_login(request: Request): with microsoft_sso: return 
await microsoft_sso.get_login_redirect() elif generic_client_id is not None: - from fastapi_sso.sso.generic import DiscoveryDocument, create_provider + from fastapi_sso.sso.base import DiscoveryDocument + from fastapi_sso.sso.generic import create_provider generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") @@ -8616,6 +8603,8 @@ async def auth_callback(request: Request): redirect_url += "sso/callback" else: redirect_url += "/sso/callback" + + result = None if google_client_id is not None: from fastapi_sso.sso.google import GoogleSSO @@ -8662,7 +8651,8 @@ async def auth_callback(request: Request): result = await microsoft_sso.verify_and_process(request) elif generic_client_id is not None: # make generic sso provider - from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider + from fastapi_sso.sso.base import DiscoveryDocument, OpenID + from fastapi_sso.sso.generic import create_provider generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") @@ -8766,8 +8756,8 @@ async def auth_callback(request: Request): verbose_proxy_logger.debug("generic result: %s", result) # User is Authe'd in - generate key for the UI to access Proxy - user_email = getattr(result, "email", None) - user_id = getattr(result, "id", None) + user_email: Optional[str] = getattr(result, "email", None) + user_id: Optional[str] = getattr(result, "id", None) if result is not None else None if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: email_domain = user_email.split("@")[1] @@ -8783,12 +8773,12 @@ async def auth_callback(request: Request): ) # generic client id - if generic_client_id is not None: + if generic_client_id is not None and result is not None: user_id = getattr(result, "id", None) user_email = getattr(result, "email", None) - user_role = getattr(result, generic_user_role_attribute_name, None) + user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore - if user_id is None: + if user_id is None and result is not None: _first_name = getattr(result, "first_name", "") or "" _last_name = getattr(result, "last_name", "") or "" user_id = _first_name + _last_name @@ -8811,54 +8801,45 @@ async def auth_callback(request: Request): "spend": 0, "team_id": "litellm-dashboard", } - user_defined_values: SSOUserDefinedValues = { - "models": user_id_models, - "user_id": user_id, - "user_email": user_email, - "max_budget": max_internal_user_budget, - "user_role": None, - "budget_duration": internal_user_budget_duration, - } + user_defined_values: Optional[SSOUserDefinedValues] = None + if user_id is not None: + user_defined_values = SSOUserDefinedValues( + models=user_id_models, + user_id=user_id, + user_email=user_email, + max_budget=max_internal_user_budget, + user_role=None, + budget_duration=internal_user_budget_duration, + ) + _user_id_from_sso = user_id + user_role = None try: - user_role = None if prisma_client is not None: user_info = await prisma_client.get_data(user_id=user_id, table_name="user") verbose_proxy_logger.debug( f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}" ) - if user_info is not None: - user_defined_values = { - "models": getattr(user_info, "models", user_id_models), - "user_id": getattr(user_info, "user_id", user_id), - "user_email": getattr(user_info, "user_id", user_email), - "user_role": getattr(user_info, 
"user_role", None), - "max_budget": getattr( - user_info, "max_budget", max_internal_user_budget - ), - "budget_duration": getattr( - user_info, "budget_duration", internal_user_budget_duration - ), - } - user_role = getattr(user_info, "user_role", None) + if user_info is None: + ## check if user-email in db ## + user_info = await prisma_client.db.litellm_usertable.find_first( + where={"user_email": user_email} + ) - ## check if user-email in db ## - user_info = await prisma_client.db.litellm_usertable.find_first( - where={"user_email": user_email} - ) - if user_info is not None: - user_defined_values = { - "models": getattr(user_info, "models", user_id_models), - "user_id": user_id, - "user_email": getattr(user_info, "user_id", user_email), - "user_role": getattr(user_info, "user_role", None), - "max_budget": getattr( + if user_info is not None and user_id is not None: + user_defined_values = SSOUserDefinedValues( + models=getattr(user_info, "models", user_id_models), + user_id=user_id, + user_email=getattr(user_info, "user_email", user_email), + user_role=getattr(user_info, "user_role", None), + max_budget=getattr( user_info, "max_budget", max_internal_user_budget ), - "budget_duration": getattr( + budget_duration=getattr( user_info, "budget_duration", internal_user_budget_duration ), - } + ) + user_role = getattr(user_info, "user_role", None) # update id @@ -8886,6 +8867,11 @@ async def auth_callback(request: Request): except Exception as e: pass + if user_defined_values is None: + raise Exception( + "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues" + ) + is_internal_user = False if ( user_defined_values["user_role"] is not None @@ -8960,7 +8946,8 @@ async def auth_callback(request: Request): master_key, algorithm="HS256", ) - litellm_dashboard_ui += "?userID=" + user_id + if user_id is not None and isinstance(user_id, str): + litellm_dashboard_ui += "?userID=" + user_id redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) redirect_response.set_cookie(key="token", value=jwt_token) return redirect_response @@ -9023,6 +9010,7 @@ async def new_invitation( "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, } # type: ignore ) + return response except Exception as e: if "Foreign key constraint failed on the field" in str(e): raise HTTPException( @@ -9031,7 +9019,7 @@ async def new_invitation( "error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`." 
}, ) - return response + raise HTTPException(status_code=500, detail={"error": str(e)}) @router.get( @@ -9951,44 +9939,46 @@ async def get_routes(): """ routes = [] for route in app.routes: - route_info = { - "path": getattr(route, "path", None), - "methods": getattr(route, "methods", None), - "name": getattr(route, "name", None), - "endpoint": ( - getattr(route, "endpoint", None).__name__ - if getattr(route, "endpoint", None) - else None - ), - } - routes.append(route_info) + endpoint_route = getattr(route, "endpoint", None) + if endpoint_route is not None: + route_info = { + "path": getattr(route, "path", None), + "methods": getattr(route, "methods", None), + "name": getattr(route, "name", None), + "endpoint": ( + endpoint_route.__name__ + if getattr(route, "endpoint", None) + else None + ), + } + routes.append(route_info) return {"routes": routes} #### TEST ENDPOINTS #### -@router.get( - "/token/generate", - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def token_generate(): - """ - Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc. - """ - # Initialize AuthJWTSSO with your OpenID Provider configuration - from fastapi_sso import AuthJWTSSO +# @router.get( +# "/token/generate", +# dependencies=[Depends(user_api_key_auth)], +# include_in_schema=False, +# ) +# async def token_generate(): +# """ +# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc. +# """ +# # Initialize AuthJWTSSO with your OpenID Provider configuration +# from fastapi_sso import AuthJWTSSO - auth_jwt_sso = AuthJWTSSO( - issuer=os.getenv("OPENID_BASE_URL"), - client_id=os.getenv("OPENID_CLIENT_ID"), - client_secret=os.getenv("OPENID_CLIENT_SECRET"), - scopes=["litellm_proxy_admin"], - ) +# auth_jwt_sso = AuthJWTSSO( +# issuer=os.getenv("OPENID_BASE_URL"), +# client_id=os.getenv("OPENID_CLIENT_ID"), +# client_secret=os.getenv("OPENID_CLIENT_SECRET"), +# scopes=["litellm_proxy_admin"], +# ) - token = auth_jwt_sso.create_access_token() +# token = auth_jwt_sso.create_access_token() - return {"token": token} +# return {"token": token} @router.on_event("shutdown") @@ -10013,7 +10003,8 @@ async def shutdown_event(): # flush langfuse logs on shutdow from litellm.utils import langFuseLogger - langFuseLogger.Langfuse.flush() + if langFuseLogger is not None: + langFuseLogger.Langfuse.flush() except: # [DO NOT BLOCK shutdown events for this] pass diff --git a/litellm/router.py b/litellm/router.py index 1628a633a..03b9fc8e4 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -92,6 +92,7 @@ from litellm.types.router import ( RetryPolicy, RouterErrors, RouterGeneralSettings, + RouterModelGroupAliasItem, RouterRateLimitError, RouterRateLimitErrorBasic, updateDeployment, @@ -105,6 +106,7 @@ from litellm.utils import ( calculate_max_parallel_requests, create_proxy_transport_and_mounts, get_llm_provider, + get_secret, get_utc_datetime, ) @@ -156,7 +158,9 @@ class Router: fallbacks: List = [], context_window_fallbacks: List = [], content_policy_fallbacks: List = [], - model_group_alias: Optional[dict] = {}, + model_group_alias: Optional[ + Dict[str, Union[str, RouterModelGroupAliasItem]] + ] = {}, enable_pre_call_checks: bool = False, enable_tag_filtering: bool = False, retry_after: int = 0, # min time to wait before retrying a failed request @@ -331,7 +335,8 @@ class Router: self.set_model_list(model_list) self.healthy_deployments: 
List = self.model_list # type: ignore for m in model_list: - self.deployment_latency_map[m["litellm_params"]["model"]] = 0 + if "model" in m["litellm_params"]: + self.deployment_latency_map[m["litellm_params"]["model"]] = 0 else: self.model_list: List = ( [] @@ -398,7 +403,7 @@ class Router: self.previous_models: List = ( [] ) # list to store failed calls (passed in as metadata to next call) - self.model_group_alias: dict = ( + self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = ( model_group_alias or {} ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group @@ -1179,6 +1184,7 @@ class Router: raise e def _image_generation(self, prompt: str, model: str, **kwargs): + model_name = "" try: verbose_router_logger.debug( f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" @@ -1269,6 +1275,7 @@ class Router: raise e async def _aimage_generation(self, prompt: str, model: str, **kwargs): + model_name = "" try: verbose_router_logger.debug( f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" @@ -1401,6 +1408,7 @@ class Router: raise e async def _atranscription(self, file: FileTypes, model: str, **kwargs): + model_name = model try: verbose_router_logger.debug( f"Inside _atranscription()- model: {model}; kwargs: {kwargs}" @@ -1781,6 +1789,7 @@ class Router: is_async: Optional[bool] = False, **kwargs, ): + messages = [{"role": "user", "content": prompt}] try: kwargs["model"] = model kwargs["prompt"] = prompt @@ -1789,7 +1798,6 @@ class Router: timeout = kwargs.get("request_timeout", self.timeout) kwargs.setdefault("metadata", {}).update({"model_group": model}) - messages = [{"role": "user", "content": prompt}] # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment( model=model, @@ -2534,13 +2542,13 @@ class Router: try: # Update kwargs with the current model name or any other model-specific adjustments ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## - _, custom_llm_provider, _, _ = get_llm_provider( + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore model=model_name["litellm_params"]["model"] ) new_kwargs = copy.deepcopy(kwargs) new_kwargs.pop("custom_llm_provider", None) return await litellm.aretrieve_batch( - custom_llm_provider=custom_llm_provider, **new_kwargs + custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore ) except Exception as e: receieved_exceptions.append(e) @@ -2616,13 +2624,13 @@ class Router: for result in results: if result is not None: ## check batch id - if final_results["first_id"] is None: - final_results["first_id"] = result.first_id - final_results["last_id"] = result.last_id + if final_results["first_id"] is None and hasattr(result, "first_id"): + final_results["first_id"] = getattr(result, "first_id") + final_results["last_id"] = getattr(result, "last_id") final_results["data"].extend(result.data) # type: ignore ## check 'has_more' - if result.has_more is True: + if getattr(result, "has_more", False) is True: final_results["has_more"] = True return final_results @@ -2874,8 +2882,12 @@ class Router: verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") original_exception = e fallback_model_group = None - original_model_group = kwargs.get("model") + original_model_group: Optional[str] = kwargs.get("model") # type: ignore fallback_failure_exception_str = "" + + if original_model_group is None: + raise e + try: verbose_router_logger.debug("Trying to fallback b/w models") if 
isinstance(e, litellm.ContextWindowExceededError): @@ -2972,7 +2984,7 @@ class Router: f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" ) if hasattr(original_exception, "message"): - original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" + original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore raise original_exception response = await run_async_fallback( @@ -2996,12 +3008,12 @@ class Router: if hasattr(original_exception, "message"): # add the available fallbacks to the exception - original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( + original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore model_group, fallback_model_group, ) if len(fallback_failure_exception_str) > 0: - original_exception.message += ( + original_exception.message += ( # type: ignore "\nError doing the fallback: {}".format( fallback_failure_exception_str ) @@ -3117,9 +3129,15 @@ class Router: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) remaining_retries = num_retries - current_attempt - _healthy_deployments, _ = await self._async_get_healthy_deployments( - model=kwargs.get("model"), - ) + _model: Optional[str] = kwargs.get("model") # type: ignore + if _model is not None: + _healthy_deployments, _ = ( + await self._async_get_healthy_deployments( + model=_model, + ) + ) + else: + _healthy_deployments = [] _timeout = self._time_to_sleep_before_retry( e=original_exception, remaining_retries=remaining_retries, @@ -3129,8 +3147,8 @@ class Router: await asyncio.sleep(_timeout) if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: - original_exception.max_retries = num_retries - original_exception.num_retries = current_attempt + setattr(original_exception, "max_retries", num_retries) + setattr(original_exception, "num_retries", current_attempt) raise original_exception @@ -3225,8 +3243,12 @@ class Router: return response except Exception as e: original_exception = e - original_model_group = kwargs.get("model") + original_model_group: Optional[str] = kwargs.get("model") verbose_router_logger.debug(f"An exception occurs {original_exception}") + + if original_model_group is None: + raise e + try: verbose_router_logger.debug( f"Trying to fallback b/w models. 
Initial model group: {model_group}" @@ -3336,10 +3358,10 @@ class Router: return 0 response_headers: Optional[httpx.Headers] = None - if hasattr(e, "response") and hasattr(e.response, "headers"): - response_headers = e.response.headers + if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore + response_headers = e.response.headers # type: ignore elif hasattr(e, "litellm_response_headers"): - response_headers = e.litellm_response_headers + response_headers = e.litellm_response_headers # type: ignore if response_headers is not None: timeout = litellm._calculate_retry_after( @@ -3398,9 +3420,13 @@ class Router: except Exception as e: current_attempt = None original_exception = e + _model: Optional[str] = kwargs.get("model") # type: ignore + + if _model is None: + raise e # re-raise error, if model can't be determined for loadbalancing ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR _healthy_deployments, _all_deployments = self._get_healthy_deployments( - model=kwargs.get("model"), + model=_model, ) # raises an exception if this error should not be retries @@ -3438,8 +3464,12 @@ class Router: except Exception as e: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) + _model: Optional[str] = kwargs.get("model") # type: ignore + + if _model is None: + raise e # re-raise error, if model can't be determined for loadbalancing _healthy_deployments, _ = self._get_healthy_deployments( - model=kwargs.get("model"), + model=_model, ) remaining_retries = num_retries - current_attempt _timeout = self._time_to_sleep_before_retry( @@ -4055,7 +4085,7 @@ class Router: if isinstance(_litellm_params, dict): for k, v in _litellm_params.items(): if isinstance(v, str) and v.startswith("os.environ/"): - _litellm_params[k] = litellm.get_secret(v) + _litellm_params[k] = get_secret(v) _model_info: dict = model.pop("model_info", {}) @@ -4392,7 +4422,6 @@ class Router: - ModelGroupInfo if able to construct a model group - None if error constructing model group info """ - model_group_info: Optional[ModelGroupInfo] = None total_tpm: Optional[int] = None @@ -4557,12 +4586,23 @@ class Router: Returns: - ModelGroupInfo if able to construct a model group - - None if error constructing model group info + - None if error constructing model group info or hidden model group """ ## Check if model group alias if model_group in self.model_group_alias: + item = self.model_group_alias[model_group] + if isinstance(item, str): + _router_model_group = item + elif isinstance(item, dict): + if item["hidden"] is True: + return None + else: + _router_model_group = item["model"] + else: + return None + return self._set_model_group_info( - model_group=self.model_group_alias[model_group], + model_group=_router_model_group, user_facing_model_group_name=model_group, ) @@ -4666,7 +4706,14 @@ class Router: Includes model_group_alias models too. 
""" - return self.model_names + list(self.model_group_alias.keys()) + model_list = self.get_model_list() + if model_list is None: + return [] + + model_names = [] + for m in model_list: + model_names.append(m["model_name"]) + return model_names def get_model_list( self, model_name: Optional[str] = None @@ -4678,9 +4725,21 @@ class Router: returned_models: List[DeploymentTypedDict] = [] for model_alias, model_value in self.model_group_alias.items(): + + if isinstance(model_value, str): + _router_model_name: str = model_value + elif isinstance(model_value, dict): + _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore + if _model_value["hidden"] is True: + continue + else: + _router_model_name = _model_value["model"] + else: + continue + returned_models.extend( self._get_all_deployments( - model_name=model_value, model_alias=model_alias + model_name=_router_model_name, model_alias=model_alias ) ) @@ -5078,10 +5137,11 @@ class Router: ) if model in self.model_group_alias: - verbose_router_logger.debug( - f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}" - ) - model = self.model_group_alias[model] + _item = self.model_group_alias[model] + if isinstance(_item, str): + model = _item + else: + model = _item["model"] if model not in self.model_names: # check if provider/ specific wildcard routing @@ -5124,7 +5184,9 @@ class Router: m for m in self.model_list if m["litellm_params"]["model"] == model ] - litellm.print_verbose(f"initial list of deployments: {healthy_deployments}") + verbose_router_logger.debug( + f"initial list of deployments: {healthy_deployments}" + ) if len(healthy_deployments) == 0: raise ValueError( @@ -5208,7 +5270,7 @@ class Router: ) # check if user wants to do tag based routing - healthy_deployments = await get_deployments_for_tag( + healthy_deployments = await get_deployments_for_tag( # type: ignore llm_router_instance=self, request_kwargs=request_kwargs, healthy_deployments=healthy_deployments, @@ -5241,7 +5303,7 @@ class Router: input=input, ) ) - if ( + elif ( self.routing_strategy == "cost-based-routing" and self.lowestcost_logger is not None ): @@ -5326,6 +5388,8 @@ class Router: ############## No RPM/TPM passed, we do a random pick ################# item = random.choice(healthy_deployments) return item or item[0] + else: + deployment = None if deployment is None: verbose_router_logger.info( f"get_available_deployment for model: {model}, No deployment available" @@ -5515,6 +5579,9 @@ class Router: messages=messages, input=input, ) + else: + deployment = None + if deployment is None: verbose_router_logger.info( f"get_available_deployment for model: {model}, No deployment available" @@ -5690,6 +5757,9 @@ class Router: def _initialize_alerting(self): from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting + if self.alerting_config is None: + return + router_alerting_config: AlertingConfig = self.alerting_config _slack_alerting_logger = SlackAlerting( @@ -5700,7 +5770,7 @@ class Router: self.slack_alerting_logger = _slack_alerting_logger - litellm.callbacks.append(_slack_alerting_logger) + litellm.callbacks.append(_slack_alerting_logger) # type: ignore litellm.success_callback.append( _slack_alerting_logger.response_taking_too_long_callback ) diff --git a/litellm/router_strategy/tag_based_routing.py b/litellm/router_strategy/tag_based_routing.py index 78bc5e4f9..deda1bd77 100644 --- a/litellm/router_strategy/tag_based_routing.py +++ b/litellm/router_strategy/tag_based_routing.py @@ 
-21,8 +21,8 @@ else: async def get_deployments_for_tag( llm_router_instance: LitellmRouter, + healthy_deployments: Union[List[Any], Dict[Any, Any]], request_kwargs: Optional[Dict[Any, Any]] = None, - healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None, ): if llm_router_instance.enable_tag_filtering is not True: return healthy_deployments diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 91ed7ea4a..1a2dab3f6 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -2828,3 +2828,44 @@ def test_gemini_function_call_parameter_in_messages_2(): ] }, ] + + +@pytest.mark.parametrize( + "base_model, metadata", + [ + (None, {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}), + ("vertex_ai/gemini-1.5-pro", None), + ], +) +def test_gemini_finetuned_endpoint(base_model, metadata): + litellm.set_verbose = True + load_vertex_ai_credentials() + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + # Set up the messages + messages = [ + {"role": "system", "content": """Use search for most queries."""}, + {"role": "user", "content": """search for weather in boston (use `search`)"""}, + ] + + client = HTTPHandler(concurrent_limit=1) + + with patch.object(client, "post", new=MagicMock()) as mock_client: + try: + response = completion( + model="vertex_ai/4965075652664360960", + messages=messages, + tool_choice="auto", + client=client, + metadata=metadata, + base_model=base_model, + ) + except Exception as e: + print(e) + + print(mock_client.call_args.kwargs) + + mock_client.assert_called() + assert mock_client.call_args.kwargs["url"].endswith( + "endpoints/4965075652664360960:generateContent" + ) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 457f645aa..6dd1bad5f 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -311,6 +311,8 @@ async def test_completion_predibase(): pass except litellm.ServiceUnavailableError as e: pass + except litellm.InternalServerError: + pass except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_health_check.py b/litellm/tests/test_health_check.py index 75f40541a..71c6c4217 100644 --- a/litellm/tests/test_health_check.py +++ b/litellm/tests/test_health_check.py @@ -133,3 +133,21 @@ async def test_fireworks_health_check(): assert response == {} return response + + +@pytest.mark.asyncio +async def test_cohere_rerank_health_check(): + response = await litellm.ahealth_check( + model_params={ + "model": "cohere/rerank-english-v3.0", + "query": "Hey, how's it going", + "documents": ["my sample text"], + "api_key": os.getenv("COHERE_API_KEY"), + }, + mode="rerank", + prompt="Hey, how's it going", + ) + + assert "error" not in response + + print(response) diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index 25088da80..8eb749e6a 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -50,6 +50,7 @@ def test_image_generation_openai(): ], # False ) # @pytest.mark.asyncio +@pytest.mark.flaky(retries=3, delay=1) async def test_image_generation_azure(sync_mode): try: if sync_mode: diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 09f12442a..3a81cc27a 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -533,7 +533,7 @@ def 
test_call_with_user_over_budget(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) - pytest.fail(f"This should have failed!. They key crossed it's budget") + pytest.fail("This should have failed!. They key crossed it's budget") asyncio.run(test()) except Exception as e: @@ -1755,7 +1755,7 @@ def test_call_with_key_over_model_budget(prisma_client): # use generated key to auth in result = await user_api_key_auth(request=request, api_key=bearer_token) print("result from user auth with new key", result) - pytest.fail(f"This should have failed!. They key crossed it's budget") + pytest.fail("This should have failed!. They key crossed it's budget") asyncio.run(test()) except Exception as e: diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 3e7d1e5e5..1250dbe24 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -589,3 +589,14 @@ def test_parse_additional_properties_json_schema(model, provider, expectedAddPro elif provider == "openai": schema = optional_params["response_format"]["json_schema"]["schema"] assert ("additionalProperties" in schema) == expectedAddProp + + +def test_o1_model_params(): + optional_params = get_optional_params( + model="o1-preview-2024-09-12", + custom_llm_provider="openai", + seed=10, + user="John", + ) + assert optional_params["seed"] == 10 + assert optional_params["user"] == "John" diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 8dc82a595..ed179d3e2 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -762,7 +762,7 @@ async def test_team_update_redis(): ) as mock_client: await _cache_team_object( team_id="1234", - team_table=LiteLLM_TeamTableCachedObj(), + team_table=LiteLLM_TeamTableCachedObj(team_id="1234"), user_api_key_cache=DualCache(), proxy_logging_obj=proxy_logging_obj, ) @@ -776,7 +776,7 @@ async def test_get_team_redis(client_no_auth): Tests if get_team_object gets value from redis cache, if set """ from litellm.caching import DualCache, RedisCache - from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object + from litellm.proxy.auth.auth_checks import get_team_object proxy_logging_obj: ProxyLogging = getattr( litellm.proxy.proxy_server, "proxy_logging_obj" @@ -917,7 +917,9 @@ async def test_create_team_member_add(prisma_client, new_member_method): ) litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client - team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj()) + team_mock_client.update = AsyncMock( + return_value=LiteLLM_TeamTableCachedObj(team_id="1234") + ) await team_member_add( data=team_member_add_request, @@ -1095,7 +1097,9 @@ async def test_create_team_member_add_team_admin( ) litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client - team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj()) + team_mock_client.update = AsyncMock( + return_value=LiteLLM_TeamTableCachedObj(team_id="1234") + ) try: await team_member_add( @@ -1434,8 +1438,9 @@ async def test_gemini_pass_through_endpoint(): print(resp.body) +@pytest.mark.parametrize("hidden", [True, False]) @pytest.mark.asyncio -async def test_proxy_model_group_alias_checks(prisma_client): +async def test_proxy_model_group_alias_checks(prisma_client, hidden): """ Check if model group alias is returned on 
@@ -1465,7 +1470,7 @@ async def test_proxy_model_group_alias_checks(prisma_client): model_alias = "gpt-4" router = litellm.Router( model_list=_model_list, - model_group_alias={model_alias: "gpt-3.5-turbo"}, + model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}}, ) setattr(litellm.proxy.proxy_server, "llm_router", router) setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list) @@ -1477,7 +1482,10 @@ async def test_proxy_model_group_alias_checks(prisma_client): user_api_key_dict=UserAPIKeyAuth(models=[]), ) - assert len(resp) == 2 + if hidden: + assert len(resp["data"]) == 1 + else: + assert len(resp["data"]) == 2 print(resp) resp = await model_info_v1( @@ -1489,7 +1497,10 @@ async def test_proxy_model_group_alias_checks(prisma_client): if model_alias == item["model_name"]: is_model_alias_in_list = True - assert is_model_alias_in_list + if hidden: + assert is_model_alias_in_list is False + else: + assert is_model_alias_in_list resp = await model_group_info( user_api_key_dict=UserAPIKeyAuth(models=[]), @@ -1500,4 +1511,7 @@ async def test_proxy_model_group_alias_checks(prisma_client): if model_alias == item.model_group: is_model_alias_in_list = True - assert is_model_alias_in_list + if hidden: + assert is_model_alias_in_list is False + else: + assert is_model_alias_in_list, f"models: {models}" diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 05d9f9f76..df34fb758 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -2481,3 +2481,31 @@ async def test_router_batch_endpoints(provider): model="my-custom-name", custom_llm_provider=provider, limit=2 ) print("list_batches=", list_batches) + + +@pytest.mark.parametrize("hidden", [True, False]) +def test_model_group_alias(hidden): + _model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo"}, + }, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + router = Router( + model_list=_model_list, + model_group_alias={ + "gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": hidden} + }, + ) + + models = router.get_model_list() + + model_names = router.get_model_names() + + if hidden: + assert len(models) == len(_model_list) + assert len(model_names) == len(_model_list) + else: + assert len(models) == len(_model_list) + 1 + assert len(model_names) == len(_model_list) + 1 diff --git a/litellm/types/router.py b/litellm/types/router.py index cb4273a6a..8c8c6a3aa 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -582,3 +582,8 @@ class RouterRateLimitError(ValueError): self.cooldown_list = cooldown_list _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. 
pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}" super().__init__(_message) + + +class RouterModelGroupAliasItem(TypedDict): + model: str + hidden: bool # if 'True', don't return on `.get_model_list` diff --git a/litellm/utils.py b/litellm/utils.py index 6b7b94a70..d3e757ae8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -76,6 +76,7 @@ from litellm.types.llms.openai import ( ChatCompletionNamedToolChoiceParam, ChatCompletionToolParam, ) +from litellm.types.utils import FileTypes # type: ignore from litellm.types.utils import ( CallTypes, ChatCompletionDeltaToolCall, @@ -84,7 +85,6 @@ from litellm.types.utils import ( Delta, Embedding, EmbeddingResponse, - FileTypes, ImageResponse, Message, ModelInfo, @@ -2339,6 +2339,7 @@ def get_litellm_params( text_completion=None, azure_ad_token_provider=None, user_continue_message=None, + base_model=None, ): litellm_params = { "acompletion": acompletion, @@ -2365,6 +2366,8 @@ def get_litellm_params( "text_completion": text_completion, "azure_ad_token_provider": azure_ad_token_provider, "user_continue_message": user_continue_message, + "base_model": base_model + or _get_base_model_from_litellm_call_metadata(metadata=metadata), } return litellm_params @@ -6063,11 +6066,11 @@ def _calculate_retry_after( max_retries: int, response_headers: Optional[httpx.Headers] = None, min_timeout: int = 0, -): +) -> Union[float, int]: retry_after = _get_retry_after_from_exception_header(response_headers) # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. - if 0 < retry_after <= 60: + if retry_after is not None and 0 < retry_after <= 60: return retry_after initial_retry_delay = 0.5 @@ -10962,6 +10965,22 @@ def get_logging_id(start_time, response_obj): return None +def _get_base_model_from_litellm_call_metadata( + metadata: Optional[dict], +) -> Optional[str]: + if metadata is None: + return None + + if metadata is not None: + model_info = metadata.get("model_info", {}) + + if model_info is not None: + base_model = model_info.get("base_model", None) + if base_model is not None: + return base_model + return None + + def _get_base_model_from_metadata(model_call_details=None): if model_call_details is None: return None @@ -10970,13 +10989,7 @@ def _get_base_model_from_metadata(model_call_details=None): if litellm_params is not None: metadata = litellm_params.get("metadata", {}) - if metadata is not None: - model_info = metadata.get("model_info", {}) - - if model_info is not None: - base_model = model_info.get("base_model", None) - if base_model is not None: - return base_model + return _get_base_model_from_litellm_call_metadata(metadata=metadata) return None diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 000000000..051dcb9fc --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,6 @@ +{ + "ignore": [], + "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"], + "reportMissingImports": false +} + \ No newline at end of file
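
A minimal usage sketch of the hidden model-group aliases introduced above (see RouterModelGroupAliasItem in litellm/types/router.py, the alias handling in Router.get_model_list()/get_model_names(), and test_model_group_alias). It only uses the constructor arguments and helpers exercised by that test; the counts in the comments mirror what the test asserts, and hidden aliases still route because get_available_deployment() resolves them without checking the hidden flag.

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
    ],
    # plain string values keep the old alias behavior; dict values use the new
    # RouterModelGroupAliasItem shape and can hide the alias from listings
    model_group_alias={
        "gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True},
        "my-gpt-4": "gpt-4",
    },
)

# hidden aliases still route, but are skipped by get_model_list()/get_model_names()
print(len(router.get_model_list()))   # 3 -> 2 real deployments + 1 visible alias
print(len(router.get_model_names()))  # 3 -> the hidden "gpt-4.5-turbo" is omitted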
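
The base_model plumbing added to get_litellm_params and the new _get_base_model_from_litellm_call_metadata helper in litellm/utils.py let a caller declare the underlying model family (e.g. for a fine-tuned Vertex AI Gemini endpoint whose model id is just a number), either as a top-level base_model kwarg or via metadata["model_info"]["base_model"], as exercised by test_gemini_finetuned_endpoint. A small sketch of the helper's behavior, assuming the private function stays importable from litellm.utils:

from litellm.utils import _get_base_model_from_litellm_call_metadata

# metadata carrying an explicit base model: the call may target a numeric
# fine-tuned endpoint, while model_info declares the underlying family
metadata = {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}

assert _get_base_model_from_litellm_call_metadata(metadata) == "vertex_ai/gemini-1.5-pro"
assert _get_base_model_from_litellm_call_metadata({"model_info": {}}) is None
assert _get_base_model_from_litellm_call_metadata(metadata=None) is None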
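
Several proxy and router call sites above switch from litellm.get_secret to get_secret imported from litellm.secret_managers.main; the resolution of "os.environ/..." references in config values is unchanged. A sketch of the two call patterns visible in the patch, assuming get_secret falls back to os.environ when no secret manager is configured; MY_REDIS_HOST and MY_REDIS_PORT are hypothetical variable names for the example:

import os
from litellm.secret_managers.main import get_secret

os.environ["MY_REDIS_HOST"] = "localhost"  # hypothetical env var for the example

# plain lookup with a default, as used for REDIS_HOST / REDIS_PORT above
host = get_secret("MY_REDIS_HOST", None)   # -> "localhost"
port = get_secret("MY_REDIS_PORT", 6379)   # -> 6379 when the variable is unset

# config values written as "os.environ/<NAME>" are passed through unchanged,
# mirroring the cache_params / litellm_params loops in ProxyConfig
same_host = get_secret("os.environ/MY_REDIS_HOST")  # -> "localhost"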