diff --git a/README.md b/README.md
index 5d3efe3559..41bb756199 100644
--- a/README.md
+++ b/README.md
@@ -271,6 +271,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ | |
 | [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ | |
 | [FriendliAI](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | ✅ | | |
+| [Galadriel](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | ✅ | | |
 
 [**Read the Docs**](https://docs.litellm.ai/docs/)
diff --git a/docs/my-website/docs/completion/document_understanding.md b/docs/my-website/docs/completion/document_understanding.md
new file mode 100644
index 0000000000..6719169aef
--- /dev/null
+++ b/docs/my-website/docs/completion/document_understanding.md
@@ -0,0 +1,202 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Using PDF Input
+
+How to send / receive PDFs (and other document types) via the `/chat/completions` endpoint.
+
+Works for:
+- Vertex AI models (Gemini + Anthropic)
+- Bedrock Models
+- Anthropic API Models
+
+## Quick Start
+
+### url
+
+
+
+```python
+import os
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+    {"type": "text", "text": "What's this file about?"},
+    {
+        "type": "image_url",
+        "image_url": image_url, # OR {"url": image_url}
+    },
+]
+
+
+if not supports_pdf_input(model, None):
+    print("Model does not support pdf input")
+
+response = completion(
+    model=model,
+    messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: bedrock-model
+    litellm_params:
+      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+```
+
+2. Start the proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "bedrock-model",
+  "messages": [
+    {
+      "role": "user",
+      "content": [
+        {"type": "text", "text": "What is this file about?"},
+        {
+          "type": "image_url",
+          "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+        }
+      ]
+    }
+  ]
+}'
+```
+
+
+
+### base64
+
+
+
+```python
+import base64
+import os
+
+import requests
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+response = requests.get(image_url)
+file_data = response.content
+
+encoded_file = base64.b64encode(file_data).decode("utf-8")
+base64_url = f"data:application/pdf;base64,{encoded_file}"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+    {"type": "text", "text": "What's this file about?"},
+    {
+        "type": "image_url",
+        "image_url": base64_url, # OR {"url": base64_url}
+    },
+]
+
+
+if not supports_pdf_input(model, None):
+    print("Model does not support pdf input")
+
+response = completion(
+    model=model,
+    messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+## Checking if a model supports pdf input
+
+
+
+Use `litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")` -> returns `True` if model can accept pdf input
+
+```python
+assert litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0") == True
+```
+
+
+
+1. Define bedrock models on config.yaml
+
+```yaml
+model_list:
+  - model_name: bedrock-model # model group name
+    litellm_params:
+      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+    model_info: # OPTIONAL - set manually
+      supports_pdf_input: True
+```
+
+2. Run proxy server
+
+```bash
+litellm --config config.yaml
+```
+
+3. Call `/model_group/info` to check if a model supports `pdf` input
+
+```shell
+curl -X 'GET' \
+  'http://localhost:4000/model_group/info' \
+  -H 'accept: application/json' \
+  -H 'x-api-key: sk-1234'
+```
+
+Expected Response
+
+```json
+{
+    "data": [
+        {
+            "model_group": "bedrock-model",
+            "providers": ["bedrock"],
+            "max_input_tokens": 128000,
+            "max_output_tokens": 16384,
+            "mode": "chat",
+            ...,
+            "supports_pdf_input": true, # 👈 supports_pdf_input is true
+        }
+    ]
+}
+```
+
+
diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 579353d652..741f8fcf29 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -706,6 +706,37 @@ print(response)
 
 
 
+## Set 'converse' / 'invoke' route
+
+LiteLLM uses the `converse` route for Bedrock models that support it, and defaults to the `invoke` route for all other models.
+
+To explicitly set the route, prefix the model name with `bedrock/converse/` or `bedrock/invoke/`.
+
+
+E.g. 
+ + + + +```python +from litellm import completion + +completion(model="bedrock/converse/us.amazon.nova-pro-v1:0") +``` + + + + +```yaml +model_list: + - model_name: bedrock-model + litellm_params: + model: bedrock/converse/us.amazon.nova-pro-v1:0 +``` + + + + ## Alternate user/assistant messages Use `user_continue_message` to add a default user message, for cases (e.g. Autogen) where the client might not follow alternating user/assistant messages starting and ending with a user message. @@ -745,6 +776,174 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \ }' ``` +## Usage - PDF / Document Understanding + +LiteLLM supports Document Understanding for Bedrock models - [AWS Bedrock Docs](https://docs.aws.amazon.com/nova/latest/userguide/modalities-document.html). + +### url + + + + +```python +from litellm.utils import supports_pdf_input, completion + +# set aws credentials +os.environ["AWS_ACCESS_KEY_ID"] = "" +os.environ["AWS_SECRET_ACCESS_KEY"] = "" +os.environ["AWS_REGION_NAME"] = "" + + +# pdf url +image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" + +# model +model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" + +image_content = [ + {"type": "text", "text": "What's this file about?"}, + { + "type": "image_url", + "image_url": image_url, # OR {"url": image_url} + }, +] + + +if not supports_pdf_input(model, None): + print("Model does not support image input") + +response = completion( + model=model, + messages=[{"role": "user", "content": image_content}], +) +assert response is not None +``` + + + +1. Setup config.yaml + +```yaml +model_list: + - model_name: bedrock-model + litellm_params: + model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0 + aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID + aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY + aws_region_name: os.environ/AWS_REGION_NAME +``` + +2. Start the proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "bedrock-model", + "messages": [ + {"role": "user", "content": {"type": "text", "text": "What's this file about?"}}, + { + "type": "image_url", + "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf", + } + ] +}' +``` + + + +### base64 + + + + +```python +from litellm.utils import supports_pdf_input, completion + +# set aws credentials +os.environ["AWS_ACCESS_KEY_ID"] = "" +os.environ["AWS_SECRET_ACCESS_KEY"] = "" +os.environ["AWS_REGION_NAME"] = "" + + +# pdf url +image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" +response = requests.get(url) +file_data = response.content + +encoded_file = base64.b64encode(file_data).decode("utf-8") +base64_url = f"data:application/pdf;base64,{encoded_file}" + +# model +model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" + +image_content = [ + {"type": "text", "text": "What's this file about?"}, + { + "type": "image_url", + "image_url": base64_url, # OR {"url": base64_url} + }, +] + + +if not supports_pdf_input(model, None): + print("Model does not support image input") + +response = completion( + model=model, + messages=[{"role": "user", "content": image_content}], +) +assert response is not None +``` + + + +1. 
Setup config.yaml + +```yaml +model_list: + - model_name: bedrock-model + litellm_params: + model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0 + aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID + aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY + aws_region_name: os.environ/AWS_REGION_NAME +``` + +2. Start the proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "bedrock-model", + "messages": [ + {"role": "user", "content": {"type": "text", "text": "What's this file about?"}}, + { + "type": "image_url", + "image_url": "data:application/pdf;base64,{b64_encoded_file}", + } + ] +}' +``` + + + + ## Boto3 - Authentication ### Passing credentials as parameters - Completion() diff --git a/docs/my-website/docs/providers/galadriel.md b/docs/my-website/docs/providers/galadriel.md new file mode 100644 index 0000000000..73f1ec8e76 --- /dev/null +++ b/docs/my-website/docs/providers/galadriel.md @@ -0,0 +1,63 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Galadriel +https://docs.galadriel.com/api-reference/chat-completion-API + +LiteLLM supports all models on Galadriel. + +## API Key +```python +import os +os.environ['GALADRIEL_API_KEY'] = "your-api-key" +``` + +## Sample Usage +```python +from litellm import completion +import os + +os.environ['GALADRIEL_API_KEY'] = "" +response = completion( + model="galadriel/llama3.1", + messages=[ + {"role": "user", "content": "hello from litellm"} + ], +) +print(response) +``` + +## Sample Usage - Streaming +```python +from litellm import completion +import os + +os.environ['GALADRIEL_API_KEY'] = "" +response = completion( + model="galadriel/llama3.1", + messages=[ + {"role": "user", "content": "hello from litellm"} + ], + stream=True +) + +for chunk in response: + print(chunk) +``` + + +## Supported Models +### Serverless Endpoints +We support ALL Galadriel AI models, just set `galadriel/` as a prefix when sending completion requests + +We support both the complete model name and the simplified name match. + +You can specify the model name either with the full name or with a simplified version e.g. `llama3.1:70b` + +| Model Name | Simplified Name | Function Call | +| -------------------------------------------------------- | -------------------------------- | ------------------------------------------------------- | +| neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 | llama3.1 or llama3.1:8b | `completion(model="galadriel/llama3.1", messages)` | +| neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16 | llama3.1:70b | `completion(model="galadriel/llama3.1:70b", messages)` | +| neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w4a16 | llama3.1:405b | `completion(model="galadriel/llama3.1:405b", messages)` | +| neuralmagic/Mistral-Nemo-Instruct-2407-quantized.w4a16 | mistral-nemo or mistral-nemo:12b | `completion(model="galadriel/mistral-nemo", messages)` | + diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index c762a0716c..875c691e37 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -134,28 +134,9 @@ general_settings: | content_policy_fallbacks | array of objects | Fallbacks to use when a ContentPolicyViolationError is encountered. 
[Further docs](./reliability#content-policy-fallbacks) | | context_window_fallbacks | array of objects | Fallbacks to use when a ContextWindowExceededError is encountered. [Further docs](./reliability#context-window-fallbacks) | | cache | boolean | If true, enables caching. [Further docs](./caching) | -| cache_params | object | Parameters for the cache. [Further docs](./caching) | -| cache_params.type | string | The type of cache to initialize. Can be one of ["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]. Defaults to "redis". [Furher docs](./caching) | -| cache_params.host | string | The host address for the Redis cache. Required if type is "redis". | -| cache_params.port | integer | The port number for the Redis cache. Required if type is "redis". | -| cache_params.password | string | The password for the Redis cache. Required if type is "redis". | -| cache_params.namespace | string | The namespace for the Redis cache. | -| cache_params.redis_startup_nodes | array of objects | Redis Cluster Settings. [Further docs](./caching) | -| cache_params.service_name | string | Redis Sentinel Settings. [Further docs](./caching) | -| cache_params.sentinel_nodes | array of arrays | Redis Sentinel Settings. [Further docs](./caching) | -| cache_params.ttl | integer | The time (in seconds) to store entries in cache. | -| cache_params.qdrant_semantic_cache_embedding_model | string | The embedding model to use for qdrant semantic cache. | -| cache_params.qdrant_collection_name | string | The name of the collection to use for qdrant semantic cache. | -| cache_params.qdrant_quantization_config | string | The quantization configuration for the qdrant semantic cache. | -| cache_params.similarity_threshold | float | The similarity threshold for the semantic cache. | -| cache_params.s3_bucket_name | string | The name of the S3 bucket to use for the semantic cache. | -| cache_params.s3_region_name | string | The region name for the S3 bucket. | -| cache_params.s3_aws_access_key_id | string | The AWS access key ID for the S3 bucket. | -| cache_params.s3_aws_secret_access_key | string | The AWS secret access key for the S3 bucket. | -| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. | -| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) | -| cache_params.mode | string | The mode of the cache. [Further docs](./caching) | +| cache_params | object | Parameters for the cache. [Further docs](./caching#supported-cache_params-on-proxy-configyaml) | | disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. | +| disable_end_user_cost_tracking_prometheus_only | boolean | If true, turns off end user cost tracking on prometheus metrics only. | | key_generation_settings | object | Restricts who can generate keys. 
[Further docs](./virtual_keys.md#restricting-key-generation) | ### general_settings - Reference diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md index a41d02bc25..f4d649fa2d 100644 --- a/docs/my-website/docs/proxy/enterprise.md +++ b/docs/my-website/docs/proxy/enterprise.md @@ -507,6 +507,11 @@ curl -X GET "http://0.0.0.0:4000/spend/logs?request_id= + + + **Step 1** Define all Params you want to enforce on config.yaml This means `["user"]` and `["metadata]["generation_name"]` are required in all LLM Requests to LiteLLM @@ -518,8 +523,21 @@ general_settings: - user - metadata.generation_name ``` + -Start LiteLLM Proxy + + +```bash +curl -L -X POST 'http://0.0.0.0:4000/key/generate' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "enforced_params": ["user", "metadata.generation_name"] +}' +``` + + + **Step 2 Verify if this works** diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 5bbb6b2a00..2107698f32 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -828,8 +828,18 @@ litellm_settings: #### Spec ```python +key_generation_settings: Optional[StandardKeyGenerationConfig] = None +``` + +#### Types + +```python +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig + class TeamUIKeyGenerationConfig(TypedDict): - allowed_team_member_roles: List[str] + allowed_team_member_roles: List[str] # either 'user' or 'admin' required_params: List[str] # require params on `/key/generate` to be set if a team key (team_id in request) is being generated @@ -838,11 +848,6 @@ class PersonalUIKeyGenerationConfig(TypedDict): required_params: List[str] # require params on `/key/generate` to be set if a personal key (no team_id in request) is being generated -class StandardKeyGenerationConfig(TypedDict, total=False): - team_key_generation: TeamUIKeyGenerationConfig - personal_key_generation: PersonalUIKeyGenerationConfig - - class LitellmUserRoles(str, enum.Enum): """ Admin Roles: diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index e6a028d831..7e1b4bb257 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -177,6 +177,7 @@ const sidebars = { "providers/ollama", "providers/perplexity", "providers/friendliai", + "providers/galadriel", "providers/groq", "providers/github", "providers/deepseek", @@ -210,6 +211,7 @@ const sidebars = { "completion/provider_specific_params", "guides/finetuned_models", "completion/audio", + "completion/document_understanding", "completion/vision", "completion/json_mode", "completion/prompt_caching", diff --git a/litellm/__init__.py b/litellm/__init__.py index d80caaf8f6..244b7e6e85 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -288,6 +288,7 @@ max_internal_user_budget: Optional[float] = None internal_user_budget_duration: Optional[str] = None max_end_user_budget: Optional[float] = None disable_end_user_cost_tracking: Optional[bool] = None +disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None #### RELIABILITY #### @@ -385,6 +386,31 @@ organization = None project = None config_path = None vertex_ai_safety_settings: Optional[dict] = None +BEDROCK_CONVERSE_MODELS = [ + "anthropic.claude-3-5-haiku-20241022-v1:0", + 
"anthropic.claude-3-5-sonnet-20241022-v2:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-v2", + "anthropic.claude-v2:1", + "anthropic.claude-v1", + "anthropic.claude-instant-v1", + "ai21.jamba-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-1-8b-instruct-v1:0", + "meta.llama3-1-70b-instruct-v1:0", + "meta.llama3-1-405b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", + "mistral.mistral-large-2407-v1:0", + "meta.llama3-2-1b-instruct-v1:0", + "meta.llama3-2-3b-instruct-v1:0", + "meta.llama3-2-11b-instruct-v1:0", + "meta.llama3-2-90b-instruct-v1:0", + "meta.llama3-2-405b-instruct-v1:0", +] ####### COMPLETION MODELS ################### open_ai_chat_completion_models: List = [] open_ai_text_completion_models: List = [] @@ -412,6 +438,7 @@ ai21_chat_models: List = [] nlp_cloud_models: List = [] aleph_alpha_models: List = [] bedrock_models: List = [] +bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS fireworks_ai_models: List = [] fireworks_ai_embedding_models: List = [] deepinfra_models: List = [] @@ -431,6 +458,7 @@ groq_models: List = [] azure_models: List = [] anyscale_models: List = [] cerebras_models: List = [] +galadriel_models: List = [] def add_known_models(): @@ -491,6 +519,8 @@ def add_known_models(): aleph_alpha_models.append(key) elif value.get("litellm_provider") == "bedrock": bedrock_models.append(key) + elif value.get("litellm_provider") == "bedrock_converse": + bedrock_converse_models.append(key) elif value.get("litellm_provider") == "deepinfra": deepinfra_models.append(key) elif value.get("litellm_provider") == "perplexity": @@ -535,6 +565,8 @@ def add_known_models(): anyscale_models.append(key) elif value.get("litellm_provider") == "cerebras": cerebras_models.append(key) + elif value.get("litellm_provider") == "galadriel": + galadriel_models.append(key) add_known_models() @@ -554,6 +586,7 @@ openai_compatible_endpoints: List = [ "inference.friendli.ai/v1", "api.sambanova.ai/v1", "api.x.ai/v1", + "api.galadriel.ai/v1", ] # this is maintained for Exception Mapping @@ -581,6 +614,7 @@ openai_compatible_providers: List = [ "litellm_proxy", "hosted_vllm", "lm_studio", + "galadriel", ] openai_text_completion_compatible_providers: List = ( [ # providers that support `/v1/completions` @@ -794,6 +828,7 @@ model_list = ( + azure_models + anyscale_models + cerebras_models + + galadriel_models ) @@ -860,6 +895,7 @@ class LlmProviders(str, Enum): LITELLM_PROXY = "litellm_proxy" HOSTED_VLLM = "hosted_vllm" LM_STUDIO = "lm_studio" + GALADRIEL = "galadriel" provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) @@ -882,7 +918,7 @@ models_by_provider: dict = { + vertex_vision_models + vertex_language_models, "ai21": ai21_models, - "bedrock": bedrock_models, + "bedrock": bedrock_models + bedrock_converse_models, "petals": petals_models, "ollama": ollama_models, "deepinfra": deepinfra_models, @@ -908,6 +944,7 @@ models_by_provider: dict = { "azure": azure_models, "anyscale": anyscale_models, "cerebras": cerebras_models, + "galadriel": galadriel_models, } # mapping for those models which have larger equivalents @@ -1067,9 +1104,6 @@ from .llms.bedrock.chat.invoke_handler import ( AmazonConverseConfig, bedrock_tool_name_mappings, ) -from .llms.bedrock.chat.converse_handler import ( - BEDROCK_CONVERSE_MODELS, -) from .llms.bedrock.common_utils import ( AmazonTitanConfig, 
AmazonAI21Config, diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 1460a1d7f0..b87f245240 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -365,7 +365,9 @@ class PrometheusLogger(CustomLogger): model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) - end_user_id = get_end_user_id_for_cost_tracking(litellm_params) + end_user_id = get_end_user_id_for_cost_tracking( + litellm_params, service_type="prometheus" + ) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] @@ -668,7 +670,9 @@ class PrometheusLogger(CustomLogger): "standard_logging_object", {} ) litellm_params = kwargs.get("litellm_params", {}) or {} - end_user_id = get_end_user_id_for_cost_tracking(litellm_params) + end_user_id = get_end_user_id_for_cost_tracking( + litellm_params, service_type="prometheus" + ) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 71eaaead0c..fea931491f 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -177,6 +177,9 @@ def get_llm_provider( # noqa: PLR0915 dynamic_api_key = get_secret_str( "FRIENDLIAI_API_KEY" ) or get_secret("FRIENDLI_TOKEN") + elif endpoint == "api.galadriel.com/v1": + custom_llm_provider = "galadriel" + dynamic_api_key = get_secret_str("GALADRIEL_API_KEY") if api_base is not None and not isinstance(api_base, str): raise Exception( @@ -526,6 +529,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 or get_secret_str("FRIENDLIAI_API_KEY") or get_secret_str("FRIENDLI_TOKEN") ) + elif custom_llm_provider == "galadriel": + api_base = ( + api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1" + ) # type: ignore + dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY") if api_base is not None and not isinstance(api_base, str): raise Exception("api base needs to be a string. 
api_base={}".format(api_base)) if dynamic_api_key is not None and not isinstance(dynamic_api_key, str): diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index e47ba4f421..743e596f8e 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -18,32 +18,6 @@ from ...base_aws_llm import BaseAWSLLM from ..common_utils import BedrockError from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call -BEDROCK_CONVERSE_MODELS = [ - "anthropic.claude-3-5-haiku-20241022-v1:0", - "anthropic.claude-3-5-sonnet-20241022-v2:0", - "anthropic.claude-3-5-sonnet-20240620-v1:0", - "anthropic.claude-3-opus-20240229-v1:0", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0", - "anthropic.claude-v2", - "anthropic.claude-v2:1", - "anthropic.claude-v1", - "anthropic.claude-instant-v1", - "ai21.jamba-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "meta.llama3-1-70b-instruct-v1:0", - "meta.llama3-1-405b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "mistral.mistral-large-2407-v1:0", - "meta.llama3-2-1b-instruct-v1:0", - "meta.llama3-2-3b-instruct-v1:0", - "meta.llama3-2-11b-instruct-v1:0", - "meta.llama3-2-90b-instruct-v1:0", - "meta.llama3-2-405b-instruct-v1:0", -] - def make_sync_call( client: Optional[HTTPHandler], @@ -53,6 +27,8 @@ def make_sync_call( model: str, messages: list, logging_obj, + json_mode: Optional[bool] = False, + fake_stream: bool = False, ): if client is None: client = _get_httpx_client() # Create a new client if none provided @@ -61,13 +37,13 @@ def make_sync_call( api_base, headers=headers, data=data, - stream=True if "ai21" not in api_base else False, + stream=not fake_stream, ) if response.status_code != 200: raise BedrockError(status_code=response.status_code, message=response.read()) - if "ai21" in api_base: + if fake_stream: model_response: ( ModelResponse ) = litellm.AmazonConverseConfig()._transform_response( @@ -83,7 +59,9 @@ def make_sync_call( print_verbose=litellm.print_verbose, encoding=litellm.encoding, ) # type: ignore - completion_stream: Any = MockResponseIterator(model_response=model_response) + completion_stream: Any = MockResponseIterator( + model_response=model_response, json_mode=json_mode + ) else: decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) @@ -130,6 +108,8 @@ class BedrockConverseLLM(BaseAWSLLM): logger_fn=None, headers={}, client: Optional[AsyncHTTPHandler] = None, + fake_stream: bool = False, + json_mode: Optional[bool] = False, ) -> CustomStreamWrapper: completion_stream = await make_call( @@ -140,6 +120,8 @@ class BedrockConverseLLM(BaseAWSLLM): model=model, messages=messages, logging_obj=logging_obj, + fake_stream=fake_stream, + json_mode=json_mode, ) streaming_response = CustomStreamWrapper( completion_stream=completion_stream, @@ -231,12 +213,15 @@ class BedrockConverseLLM(BaseAWSLLM): ## SETUP ## stream = optional_params.pop("stream", None) modelId = optional_params.pop("model_id", None) + fake_stream = optional_params.pop("fake_stream", False) + json_mode = optional_params.get("json_mode", False) if modelId is not None: modelId = self.encode_model_id(model_id=modelId) else: modelId = model - provider = model.split(".")[0] + if stream is True and "ai21" in modelId: + fake_stream = True ## CREDENTIALS ## # pop aws_secret_access_key, 
aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them @@ -290,7 +275,7 @@ class BedrockConverseLLM(BaseAWSLLM): aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, ) - if (stream is not None and stream is True) and provider != "ai21": + if (stream is not None and stream is True) and not fake_stream: endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" else: @@ -355,6 +340,8 @@ class BedrockConverseLLM(BaseAWSLLM): headers=prepped.headers, timeout=timeout, client=client, + json_mode=json_mode, + fake_stream=fake_stream, ) # type: ignore ### ASYNC COMPLETION return self.async_completion( @@ -398,6 +385,8 @@ class BedrockConverseLLM(BaseAWSLLM): model=model, messages=messages, logging_obj=logging_obj, + json_mode=json_mode, + fake_stream=fake_stream, ) streaming_response = CustomStreamWrapper( completion_stream=completion_stream, diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 23ee97a47e..e8d7e80b74 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -134,6 +134,43 @@ class AmazonConverseConfig: def get_supported_image_types(self) -> List[str]: return ["png", "jpeg", "gif", "webp"] + def get_supported_document_types(self) -> List[str]: + return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] + + def _create_json_tool_call_for_response_format( + self, + json_schema: Optional[dict] = None, + schema_name: str = "json_tool_call", + ) -> ChatCompletionToolParam: + """ + Handles creating a tool call for getting responses in JSON format. + + Args: + json_schema (Optional[dict]): The JSON schema the response should be in + + Returns: + AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format + """ + + if json_schema is None: + # Anthropic raises a 400 BadRequest error if properties is passed as None + # see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb + _input_schema = { + "type": "object", + "additionalProperties": True, + "properties": {}, + } + else: + _input_schema = json_schema + + _tool = ChatCompletionToolParam( + type="function", + function=ChatCompletionToolParamFunctionChunk( + name=schema_name, parameters=_input_schema + ), + ) + return _tool + def map_openai_params( self, model: str, @@ -160,31 +197,20 @@ class AmazonConverseConfig: - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. 
""" - if json_schema is not None: - _tool_choice = self.map_tool_choice_values( - model=model, tool_choice="required", drop_params=drop_params # type: ignore + _tool_choice = {"name": schema_name, "type": "tool"} + _tool = self._create_json_tool_call_for_response_format( + json_schema=json_schema, + schema_name=schema_name if schema_name != "" else "json_tool_call", + ) + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = ToolChoiceValuesBlock( + tool=SpecificToolChoiceBlock( + name=schema_name if schema_name != "" else "json_tool_call" ) - - _tool = ChatCompletionToolParam( - type="function", - function=ChatCompletionToolParamFunctionChunk( - name=schema_name, parameters=json_schema - ), - ) - - optional_params["tools"] = [_tool] - optional_params["tool_choice"] = _tool_choice - optional_params["json_mode"] = True - else: - if litellm.drop_params is True or drop_params is True: - pass - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format( - value - ), - status_code=400, - ) + ) + optional_params["json_mode"] = True + if non_default_params.get("stream", False) is True: + optional_params["fake_stream"] = True if param == "max_tokens" or param == "max_completion_tokens": optional_params["maxTokens"] = value if param == "stream": @@ -330,7 +356,6 @@ class AmazonConverseConfig: print_verbose, encoding, ) -> Union[ModelResponse, CustomStreamWrapper]: - ## LOGGING if logging_obj is not None: logging_obj.post_call( @@ -468,6 +493,9 @@ class AmazonConverseConfig: AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1" """ + if model.startswith("converse/"): + model = model.split("/")[1] + potential_region = model.split(".", 1)[0] if potential_region in self._supported_cross_region_inference_region(): return model.split(".", 1)[1] diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py index 7805f74dc3..3a72cf8057 100644 --- a/litellm/llms/bedrock/chat/invoke_handler.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -182,6 +182,8 @@ async def make_call( model: str, messages: list, logging_obj, + fake_stream: bool = False, + json_mode: Optional[bool] = False, ): try: if client is None: @@ -193,13 +195,13 @@ async def make_call( api_base, headers=headers, data=data, - stream=True if "ai21" not in api_base else False, + stream=not fake_stream, ) if response.status_code != 200: raise BedrockError(status_code=response.status_code, message=response.text) - if "ai21" in api_base: + if fake_stream: model_response: ( ModelResponse ) = litellm.AmazonConverseConfig()._transform_response( @@ -215,7 +217,9 @@ async def make_call( print_verbose=litellm.print_verbose, encoding=litellm.encoding, ) # type: ignore - completion_stream: Any = MockResponseIterator(model_response=model_response) + completion_stream: Any = MockResponseIterator( + model_response=model_response, json_mode=json_mode + ) else: decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.aiter_bytes( @@ -1028,6 +1032,7 @@ class BedrockLLM(BaseAWSLLM): model=model, messages=messages, logging_obj=logging_obj, + fake_stream=True if "ai21" in api_base else False, ), model=model, custom_llm_provider="bedrock", @@ -1271,21 +1276,58 @@ class AWSEventStreamDecoder: class MockResponseIterator: # for returning ai21 streaming responses - def __init__(self, model_response): + def __init__(self, model_response, json_mode: Optional[bool] 
= False): self.model_response = model_response + self.json_mode = json_mode self.is_done = False # Sync iterator def __iter__(self): return self - def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk: + def _handle_json_mode_chunk( + self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]] + ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]: + """ + If JSON mode is enabled, convert the tool call to a message. + Bedrock returns the JSON schema as part of the tool call + OpenAI returns the JSON schema as part of the content, this handles placing it in the content + + Args: + text: str + tool_use: Optional[ChatCompletionToolCallChunk] + Returns: + Tuple[str, Optional[ChatCompletionToolCallChunk]] + + text: The text to use in the content + tool_use: The ChatCompletionToolCallChunk to use in the chunk response + """ + tool_use: Optional[ChatCompletionToolCallChunk] = None + if self.json_mode is True and tool_calls is not None: + message = litellm.AnthropicConfig()._convert_tool_response_to_message( + tool_calls=tool_calls + ) + if message is not None: + text = message.content or "" + tool_use = None + elif tool_calls is not None and len(tool_calls) > 0: + tool_use = tool_calls[0] + return text, tool_use + + def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk: try: chunk_usage: litellm.Usage = getattr(chunk_data, "usage") + text = chunk_data.choices[0].message.content or "" # type: ignore + tool_use = None + if self.json_mode is True: + text, tool_use = self._handle_json_mode_chunk( + text=text, + tool_calls=chunk_data.choices[0].message.tool_calls, # type: ignore + ) processed_chunk = GChunk( - text=chunk_data.choices[0].message.content or "", # type: ignore - tool_use=None, + text=text, + tool_use=tool_use, is_finished=True, finish_reason=map_finish_reason( finish_reason=chunk_data.choices[0].finish_reason or "" @@ -1298,8 +1340,8 @@ class MockResponseIterator: # for returning ai21 streaming responses index=0, ) return processed_chunk - except Exception: - raise ValueError(f"Failed to decode chunk: {chunk_data}") + except Exception as e: + raise ValueError(f"Failed to decode chunk: {chunk_data}. 
Error: {e}") def __next__(self): if self.is_done: diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index bfd35ca475..2f55bb7bac 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2162,8 +2162,9 @@ def stringify_json_tool_call_content(messages: List) -> List: ###### AMAZON BEDROCK ####### from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock +from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock -from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock +from litellm.types.llms.bedrock import SourceBlock as BedrockSourceBlock from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock from litellm.types.llms.bedrock import ( ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock, @@ -2210,7 +2211,9 @@ def get_image_details(image_url) -> Tuple[str, str]: raise e -def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: +def _process_bedrock_converse_image_block( + image_url: str, +) -> BedrockContentBlock: if "base64" in image_url: # Case 1: Images with base64 encoding import base64 @@ -2228,12 +2231,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: else: mime_type = "image/jpeg" image_format = "jpeg" - _blob = BedrockImageSourceBlock(bytes=img_without_base_64) + _blob = BedrockSourceBlock(bytes=img_without_base_64) supported_image_formats = ( litellm.AmazonConverseConfig().get_supported_image_types() ) + supported_document_types = ( + litellm.AmazonConverseConfig().get_supported_document_types() + ) if image_format in supported_image_formats: - return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore + elif image_format in supported_document_types: + return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore else: # Handle the case when the image format is not supported raise ValueError( @@ -2244,12 +2252,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: elif "https:/" in image_url: # Case 2: Images with direct links image_bytes, image_format = get_image_details(image_url) - _blob = BedrockImageSourceBlock(bytes=image_bytes) + _blob = BedrockSourceBlock(bytes=image_bytes) supported_image_formats = ( litellm.AmazonConverseConfig().get_supported_image_types() ) + supported_document_types = ( + litellm.AmazonConverseConfig().get_supported_document_types() + ) if image_format in supported_image_formats: - return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore + elif image_format in supported_document_types: + return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore else: # Handle the case when the image format is not supported raise ValueError( @@ -2464,11 +2477,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915 _part = BedrockContentBlock(text=element["text"]) _parts.append(_part) elif element["type"] == "image_url": - image_url = element["image_url"]["url"] + if isinstance(element["image_url"], 
dict): + image_url = element["image_url"]["url"] + else: + image_url = element["image_url"] _part = _process_bedrock_converse_image_block( # type: ignore image_url=image_url ) - _parts.append(BedrockContentBlock(image=_part)) # type: ignore + _parts.append(_part) # type: ignore user_content.extend(_parts) else: _part = BedrockContentBlock(text=messages[msg_i]["content"]) @@ -2539,13 +2555,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915 assistants_part = BedrockContentBlock(text=element["text"]) assistants_parts.append(assistants_part) elif element["type"] == "image_url": - image_url = element["image_url"]["url"] + if isinstance(element["image_url"], dict): + image_url = element["image_url"]["url"] + else: + image_url = element["image_url"] assistants_part = _process_bedrock_converse_image_block( # type: ignore image_url=image_url ) - assistants_parts.append( - BedrockContentBlock(image=assistants_part) # type: ignore - ) + assistants_parts.append(assistants_part) assistant_content.extend(assistants_parts) elif messages[msg_i].get("content", None) is not None and isinstance( messages[msg_i]["content"], str diff --git a/litellm/main.py b/litellm/main.py index a32e8b6c05..d17c73c91e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2603,7 +2603,10 @@ def completion( # type: ignore # noqa: PLR0915 base_model = litellm.AmazonConverseConfig()._get_base_model(model) - if base_model in litellm.BEDROCK_CONVERSE_MODELS: + if base_model in litellm.bedrock_converse_models or model.startswith( + "converse/" + ): + model = model.replace("converse/", "") response = bedrock_converse_chat_completion.completion( model=model, messages=messages, @@ -2622,6 +2625,7 @@ def completion( # type: ignore # noqa: PLR0915 api_base=api_base, ) else: + model = model.replace("invoke/", "") response = bedrock_chat_completion.completion( model=model, messages=messages, diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index ac22871bcc..4d89091c4d 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -4795,6 +4795,42 @@ "mode": "chat", "supports_function_calling": true }, + "amazon.nova-micro-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 300000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000000035, + "output_cost_per_token": 0.00000014, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, + "amazon.nova-lite-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000006, + "output_cost_per_token": 0.00000024, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, + "amazon.nova-pro-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 300000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000008, + "output_cost_per_token": 0.0000032, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, "anthropic.claude-3-sonnet-20240229-v1:0": { "max_tokens": 4096, "max_input_tokens": 200000, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index b9fae4d254..f2569f2f28 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ 
b/litellm/proxy/_new_secret_config.yaml @@ -3,6 +3,12 @@ model_list: - model_name: gpt-4 litellm_params: model: gpt-4 + rpm: 1 + - model_name: gpt-4 + litellm_params: + model: azure/chatgpt-v-2 + api_key: os.environ/AZURE_API_KEY + api_base: os.environ/AZURE_API_BASE - model_name: rerank-model litellm_params: model: jina_ai/jina-reranker-v2-base-multilingual diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d2b417c9d4..d170cd92b6 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -667,6 +667,7 @@ class _GenerateKeyRequest(GenerateRequestBase): class GenerateKeyRequest(_GenerateKeyRequest): tags: Optional[List[str]] = None + enforced_params: Optional[List[str]] = None class GenerateKeyResponse(_GenerateKeyRequest): @@ -2190,4 +2191,5 @@ LiteLLM_ManagementEndpoint_MetadataFields = [ "model_tpm_limit", "guardrails", "tags", + "enforced_params", ] diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 21a25c8c1a..179a6fcf54 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -156,40 +156,6 @@ def common_checks( # noqa: PLR0915 raise Exception( f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}" ) - if general_settings.get("enforced_params", None) is not None: - # Enterprise ONLY Feature - # we already validate if user is premium_user when reading the config - # Add an extra premium_usercheck here too, just incase - from litellm.proxy.proxy_server import CommonProxyErrors, premium_user - - if premium_user is not True: - raise ValueError( - "Trying to use `enforced_params`" - + CommonProxyErrors.not_premium_user.value - ) - - if RouteChecks.is_llm_api_route(route=route): - # loop through each enforced param - # example enforced_params ['user', 'metadata', 'metadata.generation_name'] - for enforced_param in general_settings["enforced_params"]: - _enforced_params = enforced_param.split(".") - if len(_enforced_params) == 1: - if _enforced_params[0] not in request_body: - raise ValueError( - f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" - ) - elif len(_enforced_params) == 2: - # this is a scenario where user requires request['metadata']['generation_name'] to exist - if _enforced_params[0] not in request_body: - raise ValueError( - f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" - ) - if _enforced_params[1] not in request_body[_enforced_params[0]]: - raise ValueError( - f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param" - ) - - pass # 7. 
[OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget if ( litellm.max_budget > 0 diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9e47c20d6f..f690f517c7 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -587,6 +587,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915 f"[PROXY]returned data from litellm_pre_call_utils: {data}" ) + ## ENFORCED PARAMS CHECK + # loop through each enforced param + # example enforced_params ['user', 'metadata', 'metadata.generation_name'] + _enforced_params_check( + request_body=data, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + premium_user=premium_user, + ) + end_time = time.time() await service_logger_obj.async_service_success_hook( service=ServiceTypes.PROXY_PRE_CALL, @@ -599,6 +609,64 @@ async def add_litellm_data_to_request( # noqa: PLR0915 return data +def _get_enforced_params( + general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth +) -> Optional[list]: + enforced_params: Optional[list] = None + if general_settings is not None: + enforced_params = general_settings.get("enforced_params") + if "service_account_settings" in general_settings: + service_account_settings = general_settings["service_account_settings"] + if "enforced_params" in service_account_settings: + if enforced_params is None: + enforced_params = [] + enforced_params.extend(service_account_settings["enforced_params"]) + if user_api_key_dict.metadata.get("enforced_params", None) is not None: + if enforced_params is None: + enforced_params = [] + enforced_params.extend(user_api_key_dict.metadata["enforced_params"]) + return enforced_params + + +def _enforced_params_check( + request_body: dict, + general_settings: Optional[dict], + user_api_key_dict: UserAPIKeyAuth, + premium_user: bool, +) -> bool: + """ + If enforced params are set, check if the request body contains the enforced params. + """ + enforced_params: Optional[list] = _get_enforced_params( + general_settings=general_settings, user_api_key_dict=user_api_key_dict + ) + if enforced_params is None: + return True + if enforced_params is not None and premium_user is not True: + raise ValueError( + f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}" + ) + + for enforced_param in enforced_params: + _enforced_params = enforced_param.split(".") + if len(_enforced_params) == 1: + if _enforced_params[0] not in request_body: + raise ValueError( + f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" + ) + elif len(_enforced_params) == 2: + # this is a scenario where user requires request['metadata']['generation_name'] to exist + if _enforced_params[0] not in request_body: + raise ValueError( + f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" + ) + if _enforced_params[1] not in request_body[_enforced_params[0]]: + raise ValueError( + f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. 
This is a required param" + ) + return True + + def move_guardrails_to_metadata( data: dict, _metadata_variable_name: str, diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 287de56964..ee1b9bd8b3 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -255,6 +255,7 @@ async def generate_key_fn( # noqa: PLR0915 - tpm_limit: Optional[int] - Specify tpm limit for a given key (Tokens per minute) - soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached. - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing). + - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests) Examples: @@ -459,7 +460,6 @@ def prepare_metadata_fields( """ Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated """ - if "metadata" not in non_default_values: # allow user to set metadata to none non_default_values["metadata"] = existing_metadata.copy() @@ -469,18 +469,8 @@ def prepare_metadata_fields( try: for k, v in data_json.items(): - if k == "model_tpm_limit" or k == "model_rpm_limit": - if k not in casted_metadata or casted_metadata[k] is None: - casted_metadata[k] = {} - casted_metadata[k].update(v) - - if k == "tags" or k == "guardrails": - if k not in casted_metadata or casted_metadata[k] is None: - casted_metadata[k] = [] - seen = set(casted_metadata[k]) - casted_metadata[k].extend( - x for x in v if x not in seen and not seen.add(x) # type: ignore - ) # prevent duplicates from being added + maintain initial order + if k in LiteLLM_ManagementEndpoint_MetadataFields: + casted_metadata[k] = v except Exception as e: verbose_proxy_logger.exception( @@ -498,10 +488,9 @@ def prepare_key_update_data( ): data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) - _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"] non_default_values = {} for k, v in data_json.items(): - if k in _metadata_fields: + if k in LiteLLM_ManagementEndpoint_MetadataFields: continue non_default_values[k] = v @@ -556,6 +545,7 @@ async def update_key_fn( - team_id: Optional[str] - Team ID associated with key - models: Optional[list] - Model_name's a user is allowed to call - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only) + - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). 
[Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests) - spend: Optional[float] - Amount spent by key - max_budget: Optional[float] - Max budget for key - model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0} diff --git a/litellm/router.py b/litellm/router.py index bbee4d7bb2..800c7a9b95 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2940,6 +2940,7 @@ class Router: remaining_retries=num_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, ) await asyncio.sleep(retry_after) @@ -2972,6 +2973,7 @@ class Router: remaining_retries=remaining_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, ) await asyncio.sleep(_timeout) @@ -3149,6 +3151,7 @@ class Router: remaining_retries: int, num_retries: int, healthy_deployments: Optional[List] = None, + all_deployments: Optional[List] = None, ) -> Union[int, float]: """ Calculate back-off, then retry @@ -3157,10 +3160,14 @@ class Router: 1. there are healthy deployments in the same model group 2. there are fallbacks for the completion call """ - if ( + + ## base case - single deployment + if all_deployments is not None and len(all_deployments) == 1: + pass + elif ( healthy_deployments is not None and isinstance(healthy_deployments, list) - and len(healthy_deployments) > 1 + and len(healthy_deployments) > 0 ): return 0 @@ -3242,6 +3249,7 @@ class Router: remaining_retries=num_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, ) ## LOGGING @@ -3276,6 +3284,7 @@ class Router: remaining_retries=remaining_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, + all_deployments=_all_deployments, ) time.sleep(_timeout) diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index c80b16f6e9..88f329adeb 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -18,17 +18,24 @@ class SystemContentBlock(TypedDict): text: str -class ImageSourceBlock(TypedDict): +class SourceBlock(TypedDict): bytes: Optional[str] # base 64 encoded string class ImageBlock(TypedDict): format: Literal["png", "jpeg", "gif", "webp"] - source: ImageSourceBlock + source: SourceBlock + + +class DocumentBlock(TypedDict): + format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] + source: SourceBlock + name: str class ToolResultContentBlock(TypedDict, total=False): image: ImageBlock + document: DocumentBlock json: dict text: str @@ -48,6 +55,7 @@ class ToolUseBlock(TypedDict): class ContentBlock(TypedDict, total=False): text: str image: ImageBlock + document: DocumentBlock toolResult: ToolResultBlock toolUse: ToolUseBlock diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 93b4a39d3b..3467e1a107 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -106,6 +106,7 @@ class ModelInfo(TypedDict, total=False): supports_prompt_caching: Optional[bool] supports_audio_input: Optional[bool] supports_audio_output: Optional[bool] + supports_pdf_input: Optional[bool] tpm: Optional[int] rpm: Optional[int] diff --git a/litellm/utils.py b/litellm/utils.py index 946d81982b..86c0a60294 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3142,7 +3142,7 @@ def get_optional_params( # noqa: PLR0915 model=model, custom_llm_provider=custom_llm_provider ) base_model = litellm.AmazonConverseConfig()._get_base_model(model) - if 
base_model in litellm.BEDROCK_CONVERSE_MODELS: + if base_model in litellm.bedrock_converse_models: _check_valid_arg(supported_params=supported_params) optional_params = litellm.AmazonConverseConfig().map_openai_params( model=model, @@ -4308,6 +4308,10 @@ def _strip_stable_vertex_version(model_name) -> str: return re.sub(r"-\d+$", "", model_name) +def _strip_bedrock_region(model_name) -> str: + return litellm.AmazonConverseConfig()._get_base_model(model_name) + + def _strip_openai_finetune_model_name(model_name: str) -> str: """ Strips the organization, custom suffix, and ID from an OpenAI fine-tuned model name. @@ -4324,16 +4328,50 @@ def _strip_openai_finetune_model_name(model_name: str) -> str: return re.sub(r"(:[^:]+){3}$", "", model_name) -def _strip_model_name(model: str) -> str: - strip_version = _strip_stable_vertex_version(model_name=model) - strip_finetune = _strip_openai_finetune_model_name(model_name=strip_version) - return strip_finetune +def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str: + if custom_llm_provider and custom_llm_provider == "bedrock": + strip_bedrock_region = _strip_bedrock_region(model_name=model) + return strip_bedrock_region + elif custom_llm_provider and ( + custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini" + ): + strip_version = _strip_stable_vertex_version(model_name=model) + return strip_version + else: + strip_finetune = _strip_openai_finetune_model_name(model_name=model) + return strip_finetune def _get_model_info_from_model_cost(key: str) -> dict: return litellm.model_cost[key] +def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str]) -> bool: + """ + Check if the model info provider matches the custom provider. + """ + if custom_llm_provider and ( + "litellm_provider" in model_info + and model_info["litellm_provider"] != custom_llm_provider + ): + if custom_llm_provider == "vertex_ai" and model_info[ + "litellm_provider" + ].startswith("vertex_ai"): + return True + elif custom_llm_provider == "fireworks_ai" and model_info[ + "litellm_provider" + ].startswith("fireworks_ai"): + return True + elif custom_llm_provider == "bedrock" and model_info[ + "litellm_provider" + ].startswith("bedrock"): + return True + else: + return False + + return True + + def get_model_info( # noqa: PLR0915 model: str, custom_llm_provider: Optional[str] = None ) -> ModelInfo: @@ -4388,6 +4426,7 @@ def get_model_info( # noqa: PLR0915 supports_prompt_caching: Optional[bool] supports_audio_input: Optional[bool] supports_audio_output: Optional[bool] + supports_pdf_input: Optional[bool] Raises: Exception: If the model is not mapped yet. 
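> Editor's note: a minimal sketch (not part of the diff) of how the new `_strip_model_name` / `_check_provider_match` helpers are expected to behave, assuming the `amazon.nova-*` entries added earlier in this PR (registered with `litellm_provider: "bedrock_converse"`) and assuming `us.` is one of the supported cross-region prefixes handled by `AmazonConverseConfig._get_base_model`:

```python
from litellm.utils import _check_provider_match, _strip_model_name

# Bedrock model names are reduced to their base model before the cost-map lookup,
# so a cross-region id resolves to the plain entry (assumes "us." is a supported prefix).
assert (
    _strip_model_name(model="us.amazon.nova-pro-v1:0", custom_llm_provider="bedrock")
    == "amazon.nova-pro-v1:0"
)

# A request made with provider "bedrock" may match entries registered as
# "bedrock_converse" (e.g. the new amazon.nova-* models)...
assert _check_provider_match(
    model_info={"litellm_provider": "bedrock_converse"}, custom_llm_provider="bedrock"
)

# ...while an unrelated provider is still rejected.
assert not _check_provider_match(
    model_info={"litellm_provider": "openai"}, custom_llm_provider="bedrock"
)
```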
@@ -4445,15 +4484,21 @@ def get_model_info( # noqa: PLR0915 except Exception: split_model = model combined_model_name = model - stripped_model_name = _strip_model_name(model=model) + stripped_model_name = _strip_model_name( + model=model, custom_llm_provider=custom_llm_provider + ) combined_stripped_model_name = stripped_model_name else: split_model = model combined_model_name = "{}/{}".format(custom_llm_provider, model) - stripped_model_name = _strip_model_name(model=model) - combined_stripped_model_name = "{}/{}".format( - custom_llm_provider, _strip_model_name(model=model) + stripped_model_name = _strip_model_name( + model=model, custom_llm_provider=custom_llm_provider ) + combined_stripped_model_name = "{}/{}".format( + custom_llm_provider, + _strip_model_name(model=model, custom_llm_provider=custom_llm_provider), + ) + ######################### supported_openai_params = litellm.get_supported_openai_params( @@ -4476,6 +4521,7 @@ def get_model_info( # noqa: PLR0915 supports_function_calling=None, supports_assistant_prefill=None, supports_prompt_caching=None, + supports_pdf_input=None, ) elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat": return litellm.OllamaConfig().get_model_info(model) @@ -4488,40 +4534,25 @@ def get_model_info( # noqa: PLR0915 4. 'stripped_model_name' in litellm.model_cost. Checks if 'ft:gpt-3.5-turbo' in model map, if 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' given. 5. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" """ + _model_info: Optional[Dict[str, Any]] = None key: Optional[str] = None if combined_model_name in litellm.model_cost: key = combined_model_name _model_info = _get_model_info_from_model_cost(key=key) _model_info["supported_openai_params"] = supported_openai_params - if ( - "litellm_provider" in _model_info - and _model_info["litellm_provider"] != custom_llm_provider + if not _check_provider_match( + model_info=_model_info, custom_llm_provider=custom_llm_provider ): - if custom_llm_provider == "vertex_ai" and _model_info[ - "litellm_provider" - ].startswith("vertex_ai"): - pass - else: - _model_info = None + _model_info = None if _model_info is None and model in litellm.model_cost: key = model _model_info = _get_model_info_from_model_cost(key=key) _model_info["supported_openai_params"] = supported_openai_params - if ( - "litellm_provider" in _model_info - and _model_info["litellm_provider"] != custom_llm_provider + if not _check_provider_match( + model_info=_model_info, custom_llm_provider=custom_llm_provider ): - if custom_llm_provider == "vertex_ai" and _model_info[ - "litellm_provider" - ].startswith("vertex_ai"): - pass - elif custom_llm_provider == "fireworks_ai" and _model_info[ - "litellm_provider" - ].startswith("fireworks_ai"): - pass - else: - _model_info = None + _model_info = None if ( _model_info is None and combined_stripped_model_name in litellm.model_cost @@ -4529,57 +4560,26 @@ def get_model_info( # noqa: PLR0915 key = combined_stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) _model_info["supported_openai_params"] = supported_openai_params - if ( - "litellm_provider" in _model_info - and _model_info["litellm_provider"] != custom_llm_provider + if not _check_provider_match( + model_info=_model_info, custom_llm_provider=custom_llm_provider ): - if custom_llm_provider == "vertex_ai" and _model_info[ - "litellm_provider" - ].startswith("vertex_ai"): - pass - elif custom_llm_provider == "fireworks_ai" and _model_info[ - 
"litellm_provider" - ].startswith("fireworks_ai"): - pass - else: - _model_info = None + _model_info = None if _model_info is None and stripped_model_name in litellm.model_cost: key = stripped_model_name _model_info = _get_model_info_from_model_cost(key=key) _model_info["supported_openai_params"] = supported_openai_params - if ( - "litellm_provider" in _model_info - and _model_info["litellm_provider"] != custom_llm_provider + if not _check_provider_match( + model_info=_model_info, custom_llm_provider=custom_llm_provider ): - if custom_llm_provider == "vertex_ai" and _model_info[ - "litellm_provider" - ].startswith("vertex_ai"): - pass - elif custom_llm_provider == "fireworks_ai" and _model_info[ - "litellm_provider" - ].startswith("fireworks_ai"): - pass - else: - _model_info = None - + _model_info = None if _model_info is None and split_model in litellm.model_cost: key = split_model _model_info = _get_model_info_from_model_cost(key=key) _model_info["supported_openai_params"] = supported_openai_params - if ( - "litellm_provider" in _model_info - and _model_info["litellm_provider"] != custom_llm_provider + if not _check_provider_match( + model_info=_model_info, custom_llm_provider=custom_llm_provider ): - if custom_llm_provider == "vertex_ai" and _model_info[ - "litellm_provider" - ].startswith("vertex_ai"): - pass - elif custom_llm_provider == "fireworks_ai" and _model_info[ - "litellm_provider" - ].startswith("fireworks_ai"): - pass - else: - _model_info = None + _model_info = None if _model_info is None or key is None: raise ValueError( "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" @@ -4675,6 +4675,7 @@ def get_model_info( # noqa: PLR0915 ), supports_audio_input=_model_info.get("supports_audio_input", False), supports_audio_output=_model_info.get("supports_audio_output", False), + supports_pdf_input=_model_info.get("supports_pdf_input", False), tpm=_model_info.get("tpm", None), rpm=_model_info.get("rpm", None), ) @@ -6195,11 +6196,21 @@ class ProviderConfigManager: return OpenAIGPTConfig() -def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]: +def get_end_user_id_for_cost_tracking( + litellm_params: dict, + service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging", +) -> Optional[str]: """ Used for enforcing `disable_end_user_cost_tracking` param. + + service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking. 
""" proxy_server_request = litellm_params.get("proxy_server_request") or {} if litellm.disable_end_user_cost_tracking: return None + if ( + service_type == "prometheus" + and litellm.disable_end_user_cost_tracking_prometheus_only + ): + return None return proxy_server_request.get("body", {}).get("user", None) diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index ac22871bcc..4d89091c4d 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -4795,6 +4795,42 @@ "mode": "chat", "supports_function_calling": true }, + "amazon.nova-micro-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 300000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000000035, + "output_cost_per_token": 0.00000014, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, + "amazon.nova-lite-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000006, + "output_cost_per_token": 0.00000024, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, + "amazon.nova-pro-v1:0": { + "max_tokens": 4096, + "max_input_tokens": 300000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.0000008, + "output_cost_per_token": 0.0000032, + "litellm_provider": "bedrock_converse", + "mode": "chat", + "supports_function_calling": true, + "supports_vision": true, + "supports_pdf_input": true + }, "anthropic.claude-3-sonnet-20240229-v1:0": { "max_tokens": 4096, "max_input_tokens": 200000, diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index a4551378eb..5004d45994 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -54,6 +54,36 @@ class BaseLLMChatTest(ABC): # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None assert response.choices[0].message.content is not None + @pytest.mark.parametrize("image_url", ["str", "dict"]) + def test_pdf_handling(self, pdf_messages, image_url): + from litellm.utils import supports_pdf_input + + if image_url == "str": + image_url = pdf_messages + elif image_url == "dict": + image_url = {"url": pdf_messages} + + image_content = [ + {"type": "text", "text": "What's this file about?"}, + { + "type": "image_url", + "image_url": image_url, + }, + ] + + image_messages = [{"role": "user", "content": image_content}] + + base_completion_call_args = self.get_base_completion_call_args() + + if not supports_pdf_input(base_completion_call_args["model"], None): + pytest.skip("Model does not support image input") + + response = litellm.completion( + **base_completion_call_args, + messages=image_messages, + ) + assert response is not None + def test_message_with_name(self): litellm.set_verbose = True base_completion_call_args = self.get_base_completion_call_args() @@ -187,7 +217,7 @@ class BaseLLMChatTest(ABC): for chunk in response: content += chunk.choices[0].delta.content or "" - print("content=", content) + print(f"content={content}") # OpenAI guarantees that the JSON schema is returned in the content # relevant issue: https://github.com/BerriAI/litellm/issues/6741 @@ -250,7 +280,7 @@ class BaseLLMChatTest(ABC): import requests # URL of the file - url = 
"https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf" + url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" response = requests.get(url) file_data = response.content @@ -258,14 +288,4 @@ class BaseLLMChatTest(ABC): encoded_file = base64.b64encode(file_data).decode("utf-8") url = f"data:application/pdf;base64,{encoded_file}" - image_content = [ - {"type": "text", "text": "What's this file about?"}, - { - "type": "image_url", - "image_url": {"url": url}, - }, - ] - - image_messages = [{"role": "user", "content": image_content}] - - return image_messages + return url diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py index a61bdfd731..b5a10953d4 100644 --- a/tests/llm_translation/test_anthropic_completion.py +++ b/tests/llm_translation/test_anthropic_completion.py @@ -668,35 +668,6 @@ class TestAnthropicCompletion(BaseLLMChatTest): def get_base_completion_call_args(self) -> dict: return {"model": "claude-3-haiku-20240307"} - def test_pdf_handling(self, pdf_messages): - from litellm.llms.custom_httpx.http_handler import HTTPHandler - from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam - import json - - client = HTTPHandler() - - with patch.object(client, "post", new=MagicMock()) as mock_client: - response = completion( - model="claude-3-5-sonnet-20241022", - messages=pdf_messages, - client=client, - ) - - mock_client.assert_called_once() - - json_data = json.loads(mock_client.call_args.kwargs["data"]) - headers = mock_client.call_args.kwargs["headers"] - - assert headers["anthropic-beta"] == "pdfs-2024-09-25" - - json_data["messages"][0]["role"] == "user" - _document_validation = AnthropicMessagesDocumentParam( - **json_data["messages"][0]["content"][1] - ) - assert _document_validation["type"] == "document" - assert _document_validation["source"]["media_type"] == "application/pdf" - assert _document_validation["source"]["type"] == "base64" - def test_tool_call_no_arguments(self, tool_call_no_arguments): """Test that tool calls with no arguments is translated correctly. 
Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" from litellm.llms.prompt_templates.factory import ( diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index b2054dc232..a89b1c943d 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -30,6 +30,7 @@ from litellm import ( from litellm.llms.bedrock.chat import BedrockLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import _bedrock_tools_pt +from base_llm_unit_tests import BaseLLMChatTest # litellm.num_retries = 3 litellm.cache = None @@ -1943,3 +1944,42 @@ def test_bedrock_context_window_error(): messages=[{"role": "user", "content": "Hello, world!"}], mock_response=Exception("prompt is too long"), ) + + +def test_bedrock_converse_route(): + litellm.set_verbose = True + litellm.completion( + model="bedrock/converse/us.amazon.nova-pro-v1:0", + messages=[{"role": "user", "content": "Hello, world!"}], + ) + + +def test_bedrock_mapped_converse_models(): + litellm.set_verbose = True + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.add_known_models() + litellm.completion( + model="bedrock/us.amazon.nova-pro-v1:0", + messages=[{"role": "user", "content": "Hello, world!"}], + ) + + +def test_bedrock_base_model_helper(): + model = "us.amazon.nova-pro-v1:0" + litellm.AmazonConverseConfig()._get_base_model(model) + assert model == "us.amazon.nova-pro-v1:0" + + +class TestBedrockConverseChat(BaseLLMChatTest): + def get_base_completion_call_args(self) -> dict: + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + litellm.add_known_models() + return { + "model": "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0", + } + + def test_tool_call_no_arguments(self, tool_call_no_arguments): + """Test that tool calls with no arguments is translated correctly. 
Relevant issue: https://github.com/BerriAI/litellm/issues/6833""" + pass diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index f69778e484..6f42d9a396 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -695,29 +695,6 @@ async def test_anthropic_no_content_error(): pytest.fail(f"An unexpected error occurred - {str(e)}") -def test_gemini_completion_call_error(): - try: - print("test completion + streaming") - litellm.num_retries = 3 - litellm.set_verbose = True - messages = [{"role": "user", "content": "what is the capital of congo?"}] - response = completion( - model="gemini/gemini-1.5-pro-latest", - messages=messages, - stream=True, - max_tokens=10, - ) - print(f"response: {response}") - for chunk in response: - print(chunk) - except litellm.RateLimitError: - pass - except litellm.InternalServerError: - pass - except Exception as e: - pytest.fail(f"error occurred: {str(e)}") - - def test_completion_cohere_command_r_plus_function_call(): litellm.set_verbose = True tools = [ diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index 7b53d42db0..d3db083f68 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2342,12 +2342,6 @@ def test_router_dynamic_cooldown_correct_retry_after_time(): pass new_retry_after_mock_client.assert_called() - print( - f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}" - ) - print( - f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}" - ) response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0] assert int(response_headers["retry-after"]) == cooldown_time diff --git a/tests/local_testing/test_router_retries.py b/tests/local_testing/test_router_retries.py index 6922f55ca5..24b46b6549 100644 --- a/tests/local_testing/test_router_retries.py +++ b/tests/local_testing/test_router_retries.py @@ -539,45 +539,71 @@ def test_raise_context_window_exceeded_error_no_retry(): """ -def test_timeout_for_rate_limit_error_with_healthy_deployments(): +@pytest.mark.parametrize("num_deployments, expected_timeout", [(1, 60), (2, 0.0)]) +def test_timeout_for_rate_limit_error_with_healthy_deployments( + num_deployments, expected_timeout +): """ Test 1. Timeout is 0.0 when RateLimit Error and healthy deployments are > 0 """ - healthy_deployments = [ - "deployment1", - "deployment2", - ] # multiple healthy deployments mocked up - fallbacks = None - - router = litellm.Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - }, - } - ] + cooldown_time = 60 + rate_limit_error = litellm.RateLimitError( + message="{RouterErrors.no_deployments_available.value}. 12345 Passed model={model_group}. 
Deployments={deployment_dict}", + llm_provider="", + model="gpt-3.5-turbo", + response=httpx.Response( + status_code=429, + content="", + headers={"retry-after": str(cooldown_time)}, # type: ignore + request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore + ), ) + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + } + ] + if num_deployments == 2: + model_list.append( + { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-3.5-turbo"}, + } + ) + + router = litellm.Router(model_list=model_list) _timeout = router._time_to_sleep_before_retry( e=rate_limit_error, - remaining_retries=4, - num_retries=4, - healthy_deployments=healthy_deployments, + remaining_retries=2, + num_retries=2, + healthy_deployments=[ + { + "model_name": "gpt-4", + "litellm_params": { + "api_key": "my-key", + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com", + "model": "azure/chatgpt-v-2", + }, + "model_info": { + "id": "0e30bc8a63fa91ae4415d4234e231b3f9e6dd900cac57d118ce13a720d95e9d6", + "db_model": False, + }, + } + ], + all_deployments=model_list, ) - print( - "timeout=", - _timeout, - "error is rate_limit_error and there are healthy deployments=", - healthy_deployments, - ) - - assert _timeout == 0.0 + if expected_timeout == 0.0: + assert _timeout == expected_timeout + else: + assert _timeout > 0.0 def test_timeout_for_rate_limit_error_with_no_healthy_deployments(): @@ -585,26 +611,26 @@ def test_timeout_for_rate_limit_error_with_no_healthy_deployments(): Test 2. Timeout is > 0.0 when RateLimit Error and healthy deployments == 0 """ healthy_deployments = [] + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + } + ] - router = litellm.Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "azure/chatgpt-v-2", - "api_key": os.getenv("AZURE_API_KEY"), - "api_version": os.getenv("AZURE_API_VERSION"), - "api_base": os.getenv("AZURE_API_BASE"), - }, - } - ] - ) + router = litellm.Router(model_list=model_list) _timeout = router._time_to_sleep_before_retry( e=rate_limit_error, remaining_retries=4, num_retries=4, healthy_deployments=healthy_deployments, + all_deployments=model_list, ) print( diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 7c349a6580..54c2e8c6c3 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1005,7 +1005,10 @@ def test_models_by_provider(): continue elif k == "sample_spec": continue - elif v["litellm_provider"] == "sagemaker": + elif ( + v["litellm_provider"] == "sagemaker" + or v["litellm_provider"] == "bedrock_converse" + ): continue else: providers.add(v["litellm_provider"]) @@ -1032,3 +1035,27 @@ def test_get_end_user_id_for_cost_tracking( get_end_user_id_for_cost_tracking(litellm_params=litellm_params) == expected_end_user_id ) + + +@pytest.mark.parametrize( + "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id", + [ + ({}, False, None), + ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), + ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ], +) +def 
test_get_end_user_id_for_cost_tracking_prometheus_only( + litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id +): + from litellm.utils import get_end_user_id_for_cost_tracking + + litellm.disable_end_user_cost_tracking_prometheus_only = ( + disable_end_user_cost_tracking_prometheus_only + ) + assert ( + get_end_user_id_for_cost_tracking( + litellm_params=litellm_params, service_type="prometheus" + ) + == expected_end_user_id + ) diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index 7a2764e3f2..e1fe3e4d30 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -695,45 +695,43 @@ def test_personal_key_generation_check(): ) -def test_prepare_metadata_fields(): +@pytest.mark.parametrize( + "update_request_data, non_default_values, existing_metadata, expected_result", + [ + ( + {"metadata": {"test": "new"}}, + {"metadata": {"test": "new"}}, + {"test": "test"}, + {"metadata": {"test": "new"}}, + ), + ( + {"tags": ["new_tag"]}, + {}, + {"tags": ["old_tag"]}, + {"metadata": {"tags": ["new_tag"]}}, + ), + ( + {"enforced_params": ["metadata.tags"]}, + {}, + {"tags": ["old_tag"]}, + {"metadata": {"tags": ["old_tag"], "enforced_params": ["metadata.tags"]}}, + ), + ], +) +def test_prepare_metadata_fields( + update_request_data, non_default_values, existing_metadata, expected_result +): from litellm.proxy.management_endpoints.key_management_endpoints import ( prepare_metadata_fields, ) - new_metadata = {"test": "new"} - old_metadata = {"test": "test"} - args = { "data": UpdateKeyRequest( - key_alias=None, - duration=None, - models=[], - spend=None, - max_budget=None, - user_id=None, - team_id=None, - max_parallel_requests=None, - metadata=new_metadata, - tpm_limit=None, - rpm_limit=None, - budget_duration=None, - allowed_cache_controls=[], - soft_budget=None, - config={}, - permissions={}, - model_max_budget={}, - send_invite_email=None, - model_rpm_limit=None, - model_tpm_limit=None, - guardrails=None, - blocked=None, - aliases={}, - key="sk-1qGQUJJTcljeaPfzgWRrXQ", - tags=None, + key="sk-1qGQUJJTcljeaPfzgWRrXQ", **update_request_data ), - "non_default_values": {"metadata": new_metadata}, - "existing_metadata": {"tags": None, **old_metadata}, + "non_default_values": non_default_values, + "existing_metadata": existing_metadata, } - non_default_values = prepare_metadata_fields(**args) - assert non_default_values == {"metadata": new_metadata} + updated_non_default_values = prepare_metadata_fields(**args) + assert updated_non_default_values == expected_result diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index e1720654b1..2c8ba5b2ab 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -2664,66 +2664,6 @@ async def test_create_update_team(prisma_client): ) -@pytest.mark.asyncio() -async def test_enforced_params(prisma_client): - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - from litellm.proxy.proxy_server import general_settings - - general_settings["enforced_params"] = [ - "user", - "metadata", - "metadata.generation_name", - ] - - await litellm.proxy.proxy_server.prisma_client.connect() - request = NewUserRequest() - key = await new_user( - data=request, - user_api_key_dict=UserAPIKeyAuth( - 
user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - print(key) - - generated_key = key.key - bearer_token = "Bearer " + generated_key - - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - # Case 1: Missing user - async def return_body(): - return b'{"model": "gemini-pro-vision"}' - - request.body = return_body - try: - await user_api_key_auth(request=request, api_key=bearer_token) - pytest.fail(f"This should have failed!. IT's an invalid request") - except Exception as e: - assert ( - "BadRequest please pass param=user in request body. This is a required param" - in e.message - ) - - # Case 2: Missing metadata["generation_name"] - async def return_body_2(): - return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}' - - request.body = return_body_2 - try: - await user_api_key_auth(request=request, api_key=bearer_token) - pytest.fail(f"This should have failed!. IT's an invalid request") - except Exception as e: - assert ( - "Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body" - in e.message - ) - general_settings.pop("enforced_params") - - @pytest.mark.asyncio() async def test_update_user_role(prisma_client): """ @@ -3363,64 +3303,6 @@ async def test_auth_vertex_ai_route(prisma_client): pass -@pytest.mark.asyncio -async def test_service_accounts(prisma_client): - """ - Do not delete - this is the Admin UI flow - """ - # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - setattr( - litellm.proxy.proxy_server, - "general_settings", - {"service_account_settings": {"enforced_params": ["user"]}}, - ) - - await litellm.proxy.proxy_server.prisma_client.connect() - - request = GenerateKeyRequest( - metadata={"service_account_id": f"prod-service-{uuid.uuid4()}"}, - ) - response = await generate_key_fn( - data=request, - ) - - print("key generated=", response) - generated_key = response.key - bearer_token = "Bearer " + generated_key - # make a bad /chat/completions call expect it to fail - - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - return b'{"model": "gemini-pro-vision"}' - - request.body = return_body - - # use generated key to auth in - print("Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)) - try: - result = await user_api_key_auth(request=request, api_key=bearer_token) - pytest.fail("Expected this call to fail. 
Bad request using service account") - except Exception as e: - print("error str=", str(e.message)) - assert "This is a required param for service account" in str(e.message) - - # make a good /chat/completions call it should pass - async def good_return_body(): - return b'{"model": "gemini-pro-vision", "user": "foo"}' - - request.body = good_return_body - - result = await user_api_key_auth(request=request, api_key=bearer_token) - print("response from user_api_key_auth", result) - - setattr(litellm.proxy.proxy_server, "general_settings", {}) - - @pytest.mark.asyncio async def test_user_api_key_auth_db_unavailable(): """ diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 6de47b6eed..7910cc7ea9 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -679,3 +679,132 @@ async def test_add_litellm_data_to_request_duplicate_tags( assert sorted(result["metadata"]["tags"]) == sorted( expected_tags ), f"Expected {expected_tags}, got {result['metadata']['tags']}" + + +@pytest.mark.parametrize( + "general_settings, user_api_key_dict, expected_enforced_params", + [ + ( + {"enforced_params": ["param1", "param2"]}, + UserAPIKeyAuth( + api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" + ), + ["param1", "param2"], + ), + ( + {"service_account_settings": {"enforced_params": ["param1", "param2"]}}, + UserAPIKeyAuth( + api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" + ), + ["param1", "param2"], + ), + ( + {"service_account_settings": {"enforced_params": ["param1", "param2"]}}, + UserAPIKeyAuth( + api_key="test_api_key", + metadata={"enforced_params": ["param3", "param4"]}, + ), + ["param1", "param2", "param3", "param4"], + ), + ], +) +def test_get_enforced_params( + general_settings, user_api_key_dict, expected_enforced_params +): + from litellm.proxy.litellm_pre_call_utils import _get_enforced_params + + enforced_params = _get_enforced_params(general_settings, user_api_key_dict) + assert enforced_params == expected_enforced_params + + +@pytest.mark.parametrize( + "general_settings, user_api_key_dict, request_body, expected_error", + [ + ( + {"enforced_params": ["param1", "param2"]}, + UserAPIKeyAuth( + api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" + ), + {}, + True, + ), + ( + {"service_account_settings": {"enforced_params": ["user"]}}, + UserAPIKeyAuth( + api_key="test_api_key", user_id="test_user_id", org_id="test_org_id" + ), + {}, + True, + ), + ( + {}, + UserAPIKeyAuth( + api_key="test_api_key", + metadata={"enforced_params": ["user"]}, + ), + {}, + True, + ), + ( + {}, + UserAPIKeyAuth( + api_key="test_api_key", + metadata={"enforced_params": ["user"]}, + ), + {"user": "test_user"}, + False, + ), + ( + {"enforced_params": ["user"]}, + UserAPIKeyAuth( + api_key="test_api_key", + ), + {"user": "test_user"}, + False, + ), + ( + {"service_account_settings": {"enforced_params": ["user"]}}, + UserAPIKeyAuth( + api_key="test_api_key", + ), + {"user": "test_user"}, + False, + ), + ( + {"enforced_params": ["metadata.generation_name"]}, + UserAPIKeyAuth( + api_key="test_api_key", + ), + {"metadata": {}}, + True, + ), + ( + {"enforced_params": ["metadata.generation_name"]}, + UserAPIKeyAuth( + api_key="test_api_key", + ), + {"metadata": {"generation_name": "test_generation_name"}}, + False, + ), + ], +) +def test_enforced_params_check( + general_settings, user_api_key_dict, request_body, expected_error +): + from 
litellm.proxy.litellm_pre_call_utils import _enforced_params_check + + if expected_error: + with pytest.raises(ValueError): + _enforced_params_check( + request_body=request_body, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + premium_user=True, + ) + else: + _enforced_params_check( + request_body=request_body, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + premium_user=True, + )
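
The parametrized cases above exercise both flat params (`"user"`) and dotted params (`"metadata.generation_name"`), sourced from `general_settings`, service-account settings, or key metadata. For readers following the tests, here is a hedged, stand-alone sketch of just the dot-path validation idea; `has_enforced_param` and `check_enforced_params` are hypothetical names for this sketch and deliberately omit the `premium_user` gating and settings-merging that the real `_enforced_params_check` performs.

```python
# Toy version of the enforced-params check the tests above target (sketch only).
from typing import List


def has_enforced_param(request_body: dict, param: str) -> bool:
    """Support top-level params ("user") and one-level nesting ("metadata.generation_name")."""
    if "." in param:
        parent, child = param.split(".", 1)
        return isinstance(request_body.get(parent), dict) and child in request_body[parent]
    return param in request_body


def check_enforced_params(request_body: dict, enforced_params: List[str]) -> None:
    """Raise if any enforced param is missing from the request body."""
    missing = [p for p in enforced_params if not has_enforced_param(request_body, p)]
    if missing:
        raise ValueError(f"BadRequest: please pass param(s)={missing} in the request body")


if __name__ == "__main__":
    # Passes: both the flat and the nested param are present.
    check_enforced_params(
        {"user": "u1", "metadata": {"generation_name": "g"}},
        ["user", "metadata.generation_name"],
    )
    # Raises: nested param missing, mirroring the `{"metadata": {}}` test case above.
    try:
        check_enforced_params({"metadata": {}}, ["metadata.generation_name"])
    except ValueError as e:
        print(e)
```
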