From d46660ea0f4a4c888e618456363a04b60077c41f Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 19 Sep 2024 13:25:29 -0700 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (09/18/2024) (#5772) * fix(proxy_server.py): fix azure key vault logic to not require client id/secret * feat(cost_calculator.py): support fireworks ai cost tracking * build(docker-compose.yml): add lines for mounting config.yaml to docker compose Closes https://github.com/BerriAI/litellm/issues/5739 * fix(input.md): update docs to clarify litellm supports content as a list of dictionaries Fixes https://github.com/BerriAI/litellm/issues/5755 * fix(input.md): update input.md to include all message values * fix(image_handling.py): follow image url redirects Fixes https://github.com/BerriAI/litellm/issues/5763 * fix(router.py): Fix model key/base leak in error message Fixes https://github.com/BerriAI/litellm/issues/5762 * fix(http_handler.py): fix linting error * fix(azure.py): fix logging to show azure_ad_token being used Fixes https://github.com/BerriAI/litellm/issues/5767 * fix(_redis.py): add redis sentinel support Closes https://github.com/BerriAI/litellm/issues/4381 * feat(_redis.py): add redis sentinel support Closes https://github.com/BerriAI/litellm/issues/4381 * test(test_completion_cost.py): fix test * Databricks Integration: Integrate Databricks SDK as optional mechanism for fetching API base and token, if unspecified (#5746) * LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) * coverage (#5713) Signed-off-by: dbczumar * Move (#5714) Signed-off-by: dbczumar * fix(litellm_logging.py): fix logging client re-init (#5710) Fixes https://github.com/BerriAI/litellm/issues/5695 * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config Fixes https://github.com/BerriAI/litellm/issues/5682 * feat(o1_handler.py): fake streaming for openai o1 models Fixes https://github.com/BerriAI/litellm/issues/5694 * docs: deprecated traceloop integration in favor of native otel (#5249) * fix: fix linting errors * fix: fix linting errors * fix(main.py): fix o1 import --------- Signed-off-by: dbczumar Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730) * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it * fix(custom_logger.py): reset calltype * fix: fix linting errors * fix: fix linting error * fix Signed-off-by: dbczumar * fix: fix import * Fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * DB test Signed-off-by: dbczumar * Coverage Signed-off-by: dbczumar * progress Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix test name Signed-off-by: dbczumar --------- Signed-off-by: dbczumar Co-authored-by: Krish Dholakia Co-authored-by: Nir Gazit * test: fix test * test(test_databricks.py): fix test * fix(databricks/chat.py): handle custom endpoint (e.g. sagemaker) * Apply code scanning fix for clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix(__init__.py): fix known fireworks ai models --------- Signed-off-by: dbczumar Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- docker-compose.yml | 8 + docs/my-website/docs/completion/input.md | 7 +- docs/my-website/docs/proxy/caching.md | 54 +++++++ litellm/__init__.py | 144 +++++++++--------- litellm/_redis.py | 87 +++++++++-- litellm/cost_calculator.py | 5 + litellm/llms/AzureOpenAI/azure.py | 34 +++-- litellm/llms/custom_httpx/http_handler.py | 41 ++++- litellm/llms/databricks/chat.py | 60 ++++++-- .../chat/fireworks_ai_transformation.py} | 2 - litellm/llms/fireworks_ai/cost_calculator.py | 72 +++++++++ .../llms/prompt_templates/image_handling.py | 9 +- ...odel_prices_and_context_window_backup.json | 20 +++ litellm/proxy/_new_secret_config.yaml | 9 +- litellm/proxy/proxy_server.py | 35 +---- litellm/router.py | 35 +++-- litellm/tests/test_caching.py | 54 +++++++ litellm/tests/test_completion_cost.py | 13 ++ litellm/tests/test_prompt_factory.py | 4 + litellm/utils.py | 5 +- model_prices_and_context_window.json | 20 +++ tests/llm_translation/test_databricks.py | 143 ++++++++++++++++- .../test_fireworks_ai_translation.py | 2 +- .../test_max_completion_tokens.py | 4 +- 24 files changed, 697 insertions(+), 170 deletions(-) rename litellm/llms/{fireworks_ai.py => fireworks_ai/chat/fireworks_ai_transformation.py} (99%) create mode 100644 litellm/llms/fireworks_ai/cost_calculator.py diff --git a/docker-compose.yml b/docker-compose.yml index 6991bf7eb..4ae8b717d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,6 +9,14 @@ services: ######################################### ## Uncomment these lines to start proxy with a config.yaml file ## # volumes: + # - ./config.yaml:/app/config.yaml <<- this is missing in the docker-compose file currently + # The below two are my suggestion + # command: + # - "--config=/app/config.yaml" + ############################################## + ######################################### + ## Uncomment these lines to start proxy with a config.yaml file ## + # volumes: ############################################### ports: - "4000:4000" # Map the container port to the host, change the host port if necessary diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index 0c7c2cd92..c563a5bf0 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -124,16 +124,19 @@ def completion( #### Properties of `messages` *Note* - Each message in the array contains the following properties: -- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, or function. +- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, function or tool. -- `content`: *string or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls. +- `content`: *string or list[dict] or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls. - `name`: *string (optional)* - The name of the author of the message. It is required if the role is "function". The name should match the name of the function represented in the content. It can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters. - `function_call`: *object (optional)* - The name and arguments of a function that should be called, as generated by the model. +- `tool_call_id`: *str (optional)* - Tool call that this message is responding to. +[**See All Message Values**](https://github.com/BerriAI/litellm/blob/8600ec77042dacad324d3879a2bd918fc6a719fa/litellm/types/llms/openai.py#L392) + ## Optional Fields - `temperature`: *number or null (optional)* - The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic. diff --git a/docs/my-website/docs/proxy/caching.md b/docs/my-website/docs/proxy/caching.md index 220ef5c36..4d44a4da0 100644 --- a/docs/my-website/docs/proxy/caching.md +++ b/docs/my-website/docs/proxy/caching.md @@ -110,6 +110,60 @@ print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"]) +#### Redis Sentinel + + + + + + +```yaml +model_list: + - model_name: "*" + litellm_params: + model: "*" + + +litellm_settings: + cache: true + cache_params: + type: "redis" + service_name: "mymaster" + sentinel_nodes: [["localhost", 26379]] +``` + + + + + +You can configure redis sentinel in your .env by setting `REDIS_SENTINEL_NODES` in your .env + +**Example `REDIS_SENTINEL_NODES`** value + +```env +REDIS_SENTINEL_NODES='[["localhost", 26379]]' +REDIS_SERVICE_NAME = "mymaster" +``` + +:::note + +Example python script for setting redis cluster nodes in .env: + +```python +# List of startup nodes +sentinel_nodes = [["localhost", 26379]] + +# set startup nodes in environment variables +os.environ["REDIS_SENTINEL_NODES"] = json.dumps(sentinel_nodes) +print("REDIS_SENTINEL_NODES", os.environ["REDIS_SENTINEL_NODES"]) +``` + +::: + + + + + #### TTL ```yaml diff --git a/litellm/__init__.py b/litellm/__init__.py index db264b016..fb6234e04 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -382,73 +382,81 @@ deepinfra_models: List = [] perplexity_models: List = [] watsonx_models: List = [] gemini_models: List = [] -for key, value in model_cost.items(): - if value.get("litellm_provider") == "openai": - open_ai_chat_completion_models.append(key) - elif value.get("litellm_provider") == "text-completion-openai": - open_ai_text_completion_models.append(key) - elif value.get("litellm_provider") == "cohere": - cohere_models.append(key) - elif value.get("litellm_provider") == "cohere_chat": - cohere_chat_models.append(key) - elif value.get("litellm_provider") == "mistral": - mistral_chat_models.append(key) - elif value.get("litellm_provider") == "anthropic": - anthropic_models.append(key) - elif value.get("litellm_provider") == "empower": - empower_models.append(key) - elif value.get("litellm_provider") == "openrouter": - openrouter_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-text-models": - vertex_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-text-models": - vertex_code_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-language-models": - vertex_language_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-vision-models": - vertex_vision_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-chat-models": - vertex_chat_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-chat-models": - vertex_code_chat_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-embedding-models": - vertex_embedding_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-anthropic_models": - key = key.replace("vertex_ai/", "") - vertex_anthropic_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-llama_models": - key = key.replace("vertex_ai/", "") - vertex_llama3_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-mistral_models": - key = key.replace("vertex_ai/", "") - vertex_mistral_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-ai21_models": - key = key.replace("vertex_ai/", "") - vertex_ai_ai21_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-image-models": - key = key.replace("vertex_ai/", "") - vertex_ai_image_models.append(key) - elif value.get("litellm_provider") == "ai21": - if value.get("mode") == "chat": - ai21_chat_models.append(key) - else: - ai21_models.append(key) - elif value.get("litellm_provider") == "nlp_cloud": - nlp_cloud_models.append(key) - elif value.get("litellm_provider") == "aleph_alpha": - aleph_alpha_models.append(key) - elif value.get("litellm_provider") == "bedrock": - bedrock_models.append(key) - elif value.get("litellm_provider") == "deepinfra": - deepinfra_models.append(key) - elif value.get("litellm_provider") == "perplexity": - perplexity_models.append(key) - elif value.get("litellm_provider") == "watsonx": - watsonx_models.append(key) - elif value.get("litellm_provider") == "gemini": - gemini_models.append(key) - elif value.get("litellm_provider") == "fireworks_ai": - fireworks_ai_models.append(key) + + +def add_known_models(): + for key, value in model_cost.items(): + if value.get("litellm_provider") == "openai": + open_ai_chat_completion_models.append(key) + elif value.get("litellm_provider") == "text-completion-openai": + open_ai_text_completion_models.append(key) + elif value.get("litellm_provider") == "cohere": + cohere_models.append(key) + elif value.get("litellm_provider") == "cohere_chat": + cohere_chat_models.append(key) + elif value.get("litellm_provider") == "mistral": + mistral_chat_models.append(key) + elif value.get("litellm_provider") == "anthropic": + anthropic_models.append(key) + elif value.get("litellm_provider") == "empower": + empower_models.append(key) + elif value.get("litellm_provider") == "openrouter": + openrouter_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-text-models": + vertex_text_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-code-text-models": + vertex_code_text_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-language-models": + vertex_language_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-vision-models": + vertex_vision_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-chat-models": + vertex_chat_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-code-chat-models": + vertex_code_chat_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-embedding-models": + vertex_embedding_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-anthropic_models": + key = key.replace("vertex_ai/", "") + vertex_anthropic_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-llama_models": + key = key.replace("vertex_ai/", "") + vertex_llama3_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-mistral_models": + key = key.replace("vertex_ai/", "") + vertex_mistral_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-ai21_models": + key = key.replace("vertex_ai/", "") + vertex_ai_ai21_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-image-models": + key = key.replace("vertex_ai/", "") + vertex_ai_image_models.append(key) + elif value.get("litellm_provider") == "ai21": + if value.get("mode") == "chat": + ai21_chat_models.append(key) + else: + ai21_models.append(key) + elif value.get("litellm_provider") == "nlp_cloud": + nlp_cloud_models.append(key) + elif value.get("litellm_provider") == "aleph_alpha": + aleph_alpha_models.append(key) + elif value.get("litellm_provider") == "bedrock": + bedrock_models.append(key) + elif value.get("litellm_provider") == "deepinfra": + deepinfra_models.append(key) + elif value.get("litellm_provider") == "perplexity": + perplexity_models.append(key) + elif value.get("litellm_provider") == "watsonx": + watsonx_models.append(key) + elif value.get("litellm_provider") == "gemini": + gemini_models.append(key) + elif value.get("litellm_provider") == "fireworks_ai": + # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params. + if "-to-" not in key: + fireworks_ai_models.append(key) + + +add_known_models() # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary openai_compatible_endpoints: List = [ "api.perplexity.ai", @@ -960,7 +968,7 @@ from .llms.nvidia_nim import NvidiaNimConfig from .llms.cerebras.chat import CerebrasConfig from .llms.sambanova.chat import SambanovaConfig from .llms.AI21.chat import AI21ChatConfig -from .llms.fireworks_ai import FireworksAIConfig +from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig from .llms.volcengine import VolcEngineConfig from .llms.text_completion_codestral import MistralTextCompletionConfig from .llms.AzureOpenAI.azure import ( diff --git a/litellm/_redis.py b/litellm/_redis.py index faa98e648..152f7f09e 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -12,12 +12,13 @@ import json # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation import os -from typing import List, Optional +from typing import List, Optional, Union import redis # type: ignore import redis.asyncio as async_redis # type: ignore import litellm +from litellm import get_secret from ._logging import verbose_logger @@ -83,7 +84,7 @@ def _redis_kwargs_from_environment(): return_dict = {} for k, v in mapping.items(): - value = litellm.get_secret(k, default_value=None) # check os.environ/key vault + value = get_secret(k, default_value=None) # type: ignore if value is not None: return_dict[v] = value return return_dict @@ -116,7 +117,7 @@ def _get_redis_client_logic(**env_overrides): for k, v in env_overrides.items(): if isinstance(v, str) and v.startswith("os.environ/"): v = v.replace("os.environ/", "") - value = litellm.get_secret(v) + value = get_secret(v) # type: ignore env_overrides[k] = value redis_kwargs = { @@ -124,13 +125,27 @@ def _get_redis_client_logic(**env_overrides): **env_overrides, } - _startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret( + _startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore "REDIS_CLUSTER_NODES" ) - if _startup_nodes is not None: + if _startup_nodes is not None and isinstance(_startup_nodes, str): redis_kwargs["startup_nodes"] = json.loads(_startup_nodes) + _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore + "REDIS_SENTINEL_NODES" + ) + + if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str): + redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes) + + _service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore + "REDIS_SERVICE_NAME" + ) + + if _service_name is not None: + redis_kwargs["service_name"] = _service_name + if "url" in redis_kwargs and redis_kwargs["url"] is not None: redis_kwargs.pop("host", None) redis_kwargs.pop("port", None) @@ -138,14 +153,19 @@ def _get_redis_client_logic(**env_overrides): redis_kwargs.pop("password", None) elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None: pass + elif ( + "sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None + ): + pass elif "host" not in redis_kwargs or redis_kwargs["host"] is None: raise ValueError("Either 'host' or 'url' must be specified for redis.") + # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") return redis_kwargs def init_redis_cluster(redis_kwargs) -> redis.RedisCluster: - _redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES") + _redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore if _redis_cluster_nodes_in_env is not None: try: redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env) @@ -174,6 +194,44 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster: return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) +def _init_redis_sentinel(redis_kwargs) -> redis.Redis: + sentinel_nodes = redis_kwargs.get("sentinel_nodes") + service_name = redis_kwargs.get("service_name") + + if not sentinel_nodes or not service_name: + raise ValueError( + "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel." + ) + + verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.") + + # Set up the Sentinel client + sentinel = redis.Sentinel(sentinel_nodes, socket_timeout=0.1) + + # Return the master instance for the given service + + return sentinel.master_for(service_name) + + +def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis: + sentinel_nodes = redis_kwargs.get("sentinel_nodes") + service_name = redis_kwargs.get("service_name") + + if not sentinel_nodes or not service_name: + raise ValueError( + "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel." + ) + + verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.") + + # Set up the Sentinel client + sentinel = async_redis.Sentinel(sentinel_nodes, socket_timeout=0.1) + + # Return the master instance for the given service + + return sentinel.master_for(service_name) + + def get_redis_client(**env_overrides): redis_kwargs = _get_redis_client_logic(**env_overrides) if "url" in redis_kwargs and redis_kwargs["url"] is not None: @@ -185,12 +243,13 @@ def get_redis_client(**env_overrides): return redis.Redis.from_url(**url_kwargs) - if ( - "startup_nodes" in redis_kwargs - or litellm.get_secret("REDIS_CLUSTER_NODES") is not None - ): + if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore return init_redis_cluster(redis_kwargs) + # Check for Redis Sentinel + if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs: + return _init_redis_sentinel(redis_kwargs) + return redis.Redis(**redis_kwargs) @@ -203,7 +262,7 @@ def get_redis_async_client(**env_overrides): if arg in args: url_kwargs[arg] = redis_kwargs[arg] else: - litellm.print_verbose( + verbose_logger.debug( "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format( arg ) @@ -225,9 +284,13 @@ def get_redis_async_client(**env_overrides): new_startup_nodes.append(ClusterNode(**item)) redis_kwargs.pop("startup_nodes") return async_redis.RedisCluster( - startup_nodes=new_startup_nodes, **cluster_kwargs + startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore ) + # Check for Redis Sentinel + if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs: + return _init_async_redis_sentinel(redis_kwargs) + return async_redis.Redis( socket_timeout=5, **redis_kwargs, diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 0a935a290..a176190d0 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -25,6 +25,9 @@ from litellm.llms.anthropic.cost_calculation import ( from litellm.llms.databricks.cost_calculator import ( cost_per_token as databricks_cost_per_token, ) +from litellm.llms.fireworks_ai.cost_calculator import ( + cost_per_token as fireworks_ai_cost_per_token, +) from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS @@ -217,6 +220,8 @@ def cost_per_token( return anthropic_cost_per_token(model=model, usage=usage_block) elif custom_llm_provider == "databricks": return databricks_cost_per_token(model=model, usage=usage_block) + elif custom_llm_provider == "fireworks_ai": + return fireworks_ai_cost_per_token(model=model, usage=usage_block) elif custom_llm_provider == "gemini": return google_cost_per_token( model=model_without_prefix, diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index be3e2cbee..aee070c58 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -13,18 +13,14 @@ from pydantic import BaseModel from typing_extensions import overload import litellm -from litellm import ImageResponse, OpenAIConfig from litellm.caching import DualCache from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.utils import FileTypes # type: ignore from litellm.types.utils import EmbeddingResponse from litellm.utils import ( - Choices, CustomStreamWrapper, - Message, ModelResponse, - TranscriptionResponse, UnsupportedParamsError, convert_to_model_response_object, get_secret, @@ -674,7 +670,7 @@ class AzureChatCompletion(BaseLLM): logging_obj=logging_obj, convert_tool_call_to_json_mode=json_mode, ) - elif "stream" in optional_params and optional_params["stream"] == True: + elif "stream" in optional_params and optional_params["stream"] is True: return self.streaming( logging_obj=logging_obj, api_base=api_base, @@ -725,7 +721,11 @@ class AzureChatCompletion(BaseLLM): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token - if client is None or dynamic_params: + if ( + client is None + or not isinstance(client, AzureOpenAI) + or dynamic_params + ): azure_client = AzureOpenAI(**azure_client_params) else: azure_client = client @@ -824,7 +824,10 @@ class AzureChatCompletion(BaseLLM): input=data["messages"], api_key=azure_client.api_key, additional_args={ - "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data, @@ -930,7 +933,10 @@ class AzureChatCompletion(BaseLLM): input=data["messages"], api_key=azure_client.api_key, additional_args={ - "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data, @@ -988,7 +994,10 @@ class AzureChatCompletion(BaseLLM): input=data["messages"], api_key=azure_client.api_key, additional_args={ - "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, "api_base": azure_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data, @@ -1567,12 +1576,11 @@ class AzureChatCompletion(BaseLLM): # return response return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore except AzureOpenAIError as e: - exception_mapping_worked = True raise e except Exception as e: - if hasattr(e, "status_code"): - _status_code = getattr(e, "status_code") - raise AzureOpenAIError(status_code=_status_code, message=str(e)) + error_code = getattr(e, "status_code", None) + if error_code is not None: + raise AzureOpenAIError(status_code=error_code, message=str(e)) else: raise AzureOpenAIError(status_code=500, message=str(e)) diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index e63bd3f54..58fdf065e 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -4,6 +4,7 @@ import traceback from typing import TYPE_CHECKING, Any, Mapping, Optional, Union import httpx +from httpx import USE_CLIENT_DEFAULT import litellm @@ -76,9 +77,20 @@ class AsyncHTTPHandler: await self.client.aclose() async def get( - self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None + self, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + follow_redirects: Optional[bool] = None, ): - response = await self.client.get(url, params=params, headers=headers) + # Set follow_redirects to UseClientDefault if None + _follow_redirects = ( + follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT + ) + + response = await self.client.get( + url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore + ) return response async def post( @@ -117,8 +129,9 @@ class AsyncHTTPHandler: await new_client.aclose() except httpx.TimeoutException as e: headers = {} - if hasattr(e, "response") and e.response is not None: - for key, value in e.response.headers.items(): + error_response = getattr(e, "response", None) + if error_response is not None: + for key, value in error_response.headers.items(): headers["response_headers-{}".format(key)] = value raise litellm.Timeout( @@ -173,8 +186,9 @@ class AsyncHTTPHandler: await new_client.aclose() except httpx.TimeoutException as e: headers = {} - if hasattr(e, "response") and e.response is not None: - for key, value in e.response.headers.items(): + error_response = getattr(e, "response", None) + if error_response is not None: + for key, value in error_response.headers.items(): headers["response_headers-{}".format(key)] = value raise litellm.Timeout( @@ -303,9 +317,20 @@ class HTTPHandler: self.client.close() def get( - self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None + self, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + follow_redirects: Optional[bool] = None, ): - response = self.client.get(url, params=params, headers=headers) + # Set follow_redirects to UseClientDefault if None + _follow_redirects = ( + follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT + ) + + response = self.client.get( + url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore + ) return response def post( diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py index 343cdd3ff..23c780746 100644 --- a/litellm/llms/databricks/chat.py +++ b/litellm/llms/databricks/chat.py @@ -244,6 +244,34 @@ class DatabricksChatCompletion(BaseLLM): # makes headers for API call + def _get_databricks_credentials( + self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict] + ) -> Tuple[str, dict]: + headers = headers or {"Content-Type": "application/json"} + try: + from databricks.sdk import WorkspaceClient + + databricks_client = WorkspaceClient() + api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" + + if api_key is None: + databricks_auth_headers: dict[str, str] = ( + databricks_client.config.authenticate() + ) + headers = {**databricks_auth_headers, **headers} + + return api_base, headers + except ImportError: + raise DatabricksError( + status_code=400, + message=( + "If the Databricks base URL and API key are not set, the databricks-sdk " + "Python library must be installed. Please install the databricks-sdk, set " + "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, " + "or provide the base URL and API key as arguments." + ), + ) + def _validate_environment( self, api_key: Optional[str], @@ -253,16 +281,26 @@ class DatabricksChatCompletion(BaseLLM): headers: Optional[dict], ) -> Tuple[str, dict]: if api_key is None and headers is None: - raise DatabricksError( - status_code=400, - message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", - ) + if custom_endpoint: + raise DatabricksError( + status_code=400, + message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) if api_base is None: - raise DatabricksError( - status_code=400, - message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", - ) + if custom_endpoint: + raise DatabricksError( + status_code=400, + message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) if headers is None: headers = { @@ -273,6 +311,9 @@ class DatabricksChatCompletion(BaseLLM): if api_key is not None: headers.update({"Authorization": "Bearer {}".format(api_key)}) + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + if endpoint_type == "chat_completions" and custom_endpoint is not True: api_base = "{}/chat/completions".format(api_base) elif endpoint_type == "embeddings" and custom_endpoint is not True: @@ -520,7 +561,8 @@ class DatabricksChatCompletion(BaseLLM): response_json = response.json() except httpx.HTTPStatusError as e: raise DatabricksError( - status_code=e.response.status_code, message=e.response.text + status_code=e.response.status_code, + message=e.response.text, ) except httpx.TimeoutException as e: raise DatabricksError( diff --git a/litellm/llms/fireworks_ai.py b/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py similarity index 99% rename from litellm/llms/fireworks_ai.py rename to litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py index b6511689e..661c8352f 100644 --- a/litellm/llms/fireworks_ai.py +++ b/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py @@ -1,8 +1,6 @@ import types from typing import Literal, Optional, Union -import litellm - class FireworksAIConfig: """ diff --git a/litellm/llms/fireworks_ai/cost_calculator.py b/litellm/llms/fireworks_ai/cost_calculator.py new file mode 100644 index 000000000..83ce97d4c --- /dev/null +++ b/litellm/llms/fireworks_ai/cost_calculator.py @@ -0,0 +1,72 @@ +""" +For calculating cost of fireworks ai serverless inference models. +""" + +from typing import Tuple + +from litellm.types.utils import Usage +from litellm.utils import get_model_info + + +# Extract the number of billion parameters from the model name +# only used for together_computer LLMs +def get_model_params_and_category(model_name: str) -> str: + """ + Helper function for calculating together ai pricing. + + Returns: + - str: model pricing category if mapped else received model name + """ + import re + + model_name = model_name.lower() + + # Check for MoE models in the form xb + moe_match = re.search(r"(\d+)x(\d+)b", model_name) + if moe_match: + total_billion = int(moe_match.group(1)) * int(moe_match.group(2)) + if total_billion <= 56: + return "fireworks-ai-moe-up-to-56b" + elif total_billion <= 176: + return "fireworks-ai-56b-to-176b" + + # Check for standard models in the form b + re_params_match = re.search(r"(\d+)b", model_name) + if re_params_match is not None: + params_match = str(re_params_match.group(1)) + params_billion = float(params_match) + + # Determine the category based on the number of parameters + if params_billion <= 16.0: + return "fireworks-ai-up-to-16b" + elif params_billion <= 80.0: + return "fireworks-ai-16b-80b" + + # If no matches, return the original model_name + return model_name + + +def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]: + """ + Calculates the cost per token for a given model, prompt tokens, and completion tokens. + + Input: + - model: str, the model name without provider prefix + - usage: LiteLLM Usage block, containing anthropic caching information + + Returns: + Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd + """ + base_model = get_model_params_and_category(model_name=model) + + ## GET MODEL INFO + model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai") + + ## CALCULATE INPUT COST + + prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"] + + ## CALCULATE OUTPUT COST + completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"] + + return prompt_cost, completion_cost diff --git a/litellm/llms/prompt_templates/image_handling.py b/litellm/llms/prompt_templates/image_handling.py index 90db3dedc..c91c05df9 100644 --- a/litellm/llms/prompt_templates/image_handling.py +++ b/litellm/llms/prompt_templates/image_handling.py @@ -7,6 +7,7 @@ import base64 from httpx import Response import litellm +from litellm import verbose_logger from litellm.caching import InMemoryCache from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, @@ -58,7 +59,7 @@ async def async_convert_url_to_base64(url: str) -> str: client = litellm.module_level_aclient for _ in range(3): try: - response = await client.get(url) + response = await client.get(url, follow_redirects=True) return _process_image_response(response, url) except: pass @@ -75,9 +76,11 @@ def convert_url_to_base64(url: str) -> str: client = litellm.module_level_client for _ in range(3): try: - response = client.get(url) + response = client.get(url, follow_redirects=True) return _process_image_response(response, url) - except: + except Exception as e: + verbose_logger.exception(e) + # print(e) pass raise Exception( f"Error: Unable to fetch image from URL after 3 attempts. url={url}" diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index dbd9bd73b..23389d530 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -5452,6 +5452,26 @@ "mode": "chat", "supports_function_calling": true, "source": "https://fireworks.ai/pricing" + }, + "fireworks-ai-up-to-16b": { + "input_cost_per_token": 0.0000002, + "output_cost_per_token": 0.0000002, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-16.1b-to-80b": { + "input_cost_per_token": 0.0000009, + "output_cost_per_token": 0.0000009, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-moe-up-to-56b": { + "input_cost_per_token": 0.0000005, + "output_cost_per_token": 0.0000005, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-56b-to-176b": { + "input_cost_per_token": 0.0000012, + "output_cost_per_token": 0.0000012, + "litellm_provider": "fireworks_ai" }, "anyscale/mistralai/Mistral-7B-Instruct-v0.1": { "max_tokens": 16384, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a2911dca6..5773f9f51 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -19,4 +19,11 @@ model_list: - model_name: o1-preview litellm_params: model: o1-preview - \ No newline at end of file + +litellm_settings: + cache: true + # cache_params: + # type: "redis" + # service_name: "mymaster" + # sentinel_nodes: + # - ["localhost", 26379] \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 20dab118b..613dacb57 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -653,38 +653,19 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): return try: - from azure.identity import ClientSecretCredential, DefaultAzureCredential + from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient # Set your Azure Key Vault URI KVUri = os.getenv("AZURE_KEY_VAULT_URI", None) - # Set your Azure AD application/client ID, client secret, and tenant ID - client_id = os.getenv("AZURE_CLIENT_ID", None) - client_secret = os.getenv("AZURE_CLIENT_SECRET", None) - tenant_id = os.getenv("AZURE_TENANT_ID", None) + credential = DefaultAzureCredential() - if ( - KVUri is not None - and client_id is not None - and client_secret is not None - and tenant_id is not None - ): - # Initialize the ClientSecretCredential - # credential = ClientSecretCredential( - # client_id=client_id, client_secret=client_secret, tenant_id=tenant_id - # ) - credential = DefaultAzureCredential() + # Create the SecretClient using the credential + client = SecretClient(vault_url=KVUri, credential=credential) - # Create the SecretClient using the credential - client = SecretClient(vault_url=KVUri, credential=credential) - - litellm.secret_manager_client = client - litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT - else: - raise Exception( - f"Missing KVUri or client_id or client_secret or tenant_id from environment" - ) + litellm.secret_manager_client = client + litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT except Exception as e: _error_str = str(e) verbose_proxy_logger.exception( @@ -1626,8 +1607,8 @@ class ProxyConfig: ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables self._init_cache(cache_params=cache_params) if litellm.cache is not None: - verbose_proxy_logger.debug( # noqa - f"{blue_color_code}Set Cache on LiteLLM Proxy= {vars(litellm.cache.cache)}{vars(litellm.cache)}{reset_color_code}" + verbose_proxy_logger.debug( + f"{blue_color_code}Set Cache on LiteLLM Proxy{reset_color_code}" ) elif key == "cache" and value is False: pass diff --git a/litellm/router.py b/litellm/router.py index 2a3f583fa..780eeb3e7 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4019,7 +4019,9 @@ class Router: _model_info=_model_info, ) - verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}") + verbose_router_logger.debug( + f"\nInitialized Model List {self.get_model_names()}" + ) self.model_names = [m["model_name"] for m in model_list] def _add_deployment(self, deployment: Deployment) -> Deployment: @@ -4630,24 +4632,25 @@ class Router: if hasattr(self, "model_list"): returned_models: List[DeploymentTypedDict] = [] - for model_alias, model_value in self.model_group_alias.items(): + if hasattr(self, "model_group_alias"): + for model_alias, model_value in self.model_group_alias.items(): - if isinstance(model_value, str): - _router_model_name: str = model_value - elif isinstance(model_value, dict): - _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore - if _model_value["hidden"] is True: - continue + if isinstance(model_value, str): + _router_model_name: str = model_value + elif isinstance(model_value, dict): + _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore + if _model_value["hidden"] is True: + continue + else: + _router_model_name = _model_value["model"] else: - _router_model_name = _model_value["model"] - else: - continue + continue - returned_models.extend( - self._get_all_deployments( - model_name=_router_model_name, model_alias=model_alias + returned_models.extend( + self._get_all_deployments( + model_name=_router_model_name, model_alias=model_alias + ) ) - ) if model_name is None: returned_models += self.model_list @@ -5030,7 +5033,7 @@ class Router: # return the first deployment where the `model` matches the specificed deployment name return deployment_model, deployment raise ValueError( - f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}" + f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}" ) elif model in self.get_model_ids(): deployment = self.get_model_info(id=model) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index a52f17847..5244fa4b7 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -2045,3 +2045,57 @@ async def test_proxy_logging_setup(): pl_obj = ProxyLogging(user_api_key_cache=DualCache()) assert pl_obj.internal_usage_cache.always_read_redis is True + + +@pytest.mark.skip(reason="local test. Requires sentinel setup.") +@pytest.mark.asyncio +async def test_redis_sentinel_caching(): + """ + Init redis client + - write to client + - read from client + """ + litellm.set_verbose = False + + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + {"role": "user", "content": f"write a one sentence poem about: {random_number}"} + ] + + litellm.cache = Cache( + type="redis", + # host=os.environ["REDIS_HOST"], + # port=os.environ["REDIS_PORT"], + # password=os.environ["REDIS_PASSWORD"], + service_name="mymaster", + sentinel_nodes=[("localhost", 26379)], + ) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + ) + + cache_key = litellm.cache.get_cache_key( + model="gpt-3.5-turbo", + messages=messages, + ) + print(f"cache_key: {cache_key}") + litellm.cache.add_cache(result=response1, cache_key=cache_key) + print(f"cache key pre async get: {cache_key}") + stored_val = litellm.cache.get_cache( + model="gpt-3.5-turbo", + messages=messages, + ) + + print(f"stored_val: {stored_val}") + assert stored_val["id"] == response1.id + + stored_val_2 = await litellm.cache.async_get_cache( + model="gpt-3.5-turbo", + messages=messages, + ) + + print(f"stored_val: {stored_val}") + assert stored_val_2["id"] == response1.id diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py index 835cced45..82257fad8 100644 --- a/litellm/tests/test_completion_cost.py +++ b/litellm/tests/test_completion_cost.py @@ -1255,3 +1255,16 @@ def test_completion_cost_databricks_embedding(model): print(resp) cost = completion_cost(completion_response=resp) + + +def test_completion_cost_fireworks_ai(): + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" + litellm.model_cost = litellm.get_model_cost_map(url="") + + messages = [{"role": "user", "content": "Hey, how's it going?"}] + resp = litellm.completion( + model="fireworks_ai/mixtral-8x7b-instruct", messages=messages + ) # works fine + + print(resp) + cost = completion_cost(completion_response=resp) diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 4c99efb3e..5f285e00a 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -432,3 +432,7 @@ def test_vertex_only_image_user_message(): ), "Invalid gemini input. Got={}, Expected={}".format( content, expected_response[idx] ) + + +def test_convert_url(): + convert_url_to_base64("https://picsum.photos/id/237/200/300") diff --git a/litellm/utils.py b/litellm/utils.py index 7c5ef4248..7d6d5223c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6830,7 +6830,10 @@ def exception_type( llm_provider=custom_llm_provider, model=model, ) - elif original_exception.status_code == 401: + elif ( + original_exception.status_code == 401 + or original_exception.status_code == 403 + ): exception_mapping_worked = True raise AuthenticationError( message=f"{custom_llm_provider}Exception - {original_exception.message}", diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index dbd9bd73b..23389d530 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -5452,6 +5452,26 @@ "mode": "chat", "supports_function_calling": true, "source": "https://fireworks.ai/pricing" + }, + "fireworks-ai-up-to-16b": { + "input_cost_per_token": 0.0000002, + "output_cost_per_token": 0.0000002, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-16.1b-to-80b": { + "input_cost_per_token": 0.0000009, + "output_cost_per_token": 0.0000009, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-moe-up-to-56b": { + "input_cost_per_token": 0.0000005, + "output_cost_per_token": 0.0000005, + "litellm_provider": "fireworks_ai" + }, + "fireworks-ai-56b-to-176b": { + "input_cost_per_token": 0.0000012, + "output_cost_per_token": 0.0000012, + "litellm_provider": "fireworks_ai" }, "anyscale/mistralai/Mistral-7B-Instruct-v0.1": { "max_tokens": 16384, diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py index 067b188ed..b3bd92d8d 100644 --- a/tests/llm_translation/test_databricks.py +++ b/tests/llm_translation/test_databricks.py @@ -7,10 +7,17 @@ from typing import Any, Dict, List from unittest.mock import MagicMock, Mock, patch import litellm -from litellm.exceptions import BadRequestError, InternalServerError +from litellm.exceptions import BadRequestError from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.utils import CustomStreamWrapper +try: + import databricks.sdk + + databricks_sdk_installed = True +except ImportError: + databricks_sdk_installed = False + def mock_chat_response() -> Dict[str, Any]: return { @@ -33,8 +40,8 @@ def mock_chat_response() -> Dict[str, Any]: "usage": { "prompt_tokens": 230, "completion_tokens": 38, - "total_tokens": 268, "completion_tokens_details": None, + "total_tokens": 268, }, "system_fingerprint": None, } @@ -195,7 +202,14 @@ def mock_embedding_response() -> Dict[str, Any]: @pytest.mark.parametrize("set_base", [True, False]) -def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base): +def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk( + monkeypatch, set_base +): + # Simulate that the databricks SDK is not installed + monkeypatch.setitem(sys.modules, "databricks.sdk", None) + + err_msg = "the Databricks base URL and API key are not set" + if set_base: monkeypatch.setenv( "DATABRICKS_API_BASE", @@ -204,11 +218,11 @@ def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base): monkeypatch.delenv( "DATABRICKS_API_KEY", ) - err_msg = "A call is being made to LLM Provider but no key is set" else: monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey") - monkeypatch.delenv("DATABRICKS_API_BASE") - err_msg = "A call is being made to LLM Provider but no api base is set" + monkeypatch.delenv( + "DATABRICKS_API_BASE", + ) with pytest.raises(BadRequestError) as exc: litellm.completion( @@ -422,6 +436,67 @@ def test_completions_streaming_with_async_http_handler(monkeypatch): ) +@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed") +def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch): + from databricks.sdk import WorkspaceClient + from databricks.sdk.config import Config + + sync_handler = HTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_chat_response() + + expected_response_json = { + **mock_chat_response(), + **{ + "model": "databricks/dbrx-instruct-071224", + }, + } + + base_url = "https://my.workspace.cloud.databricks.com" + api_key = "dapimykey" + headers = { + "Authorization": f"Bearer {api_key}", + } + messages = [{"role": "user", "content": "How are you?"}] + + mock_workspace_client: WorkspaceClient = MagicMock() + mock_config: Config = MagicMock() + # Simulate the behavior of the config property and its methods + mock_config.authenticate.side_effect = lambda: headers + mock_config.host = base_url # Assign directly as if it's a property + mock_workspace_client.config = mock_config + + with patch( + "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client + ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post: + response = litellm.completion( + model="databricks/dbrx-instruct-071224", + messages=messages, + client=sync_handler, + temperature=0.5, + extraparam="testpassingextraparam", + ) + assert response.to_dict() == expected_response_json + + mock_post.assert_called_once_with( + f"{base_url}/serving-endpoints/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "extraparam": "testpassingextraparam", + "stream": False, + } + ), + ) + + def test_embeddings_with_sync_http_handler(monkeypatch): base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" api_key = "dapimykey" @@ -500,3 +575,59 @@ def test_embeddings_with_async_http_handler(monkeypatch): } ), ) + + +@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed") +def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch): + from databricks.sdk import WorkspaceClient + from databricks.sdk.config import Config + + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + sync_handler = HTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_embedding_response() + + base_url = "https://my.workspace.cloud.databricks.com" + api_key = "dapimykey" + headers = { + "Authorization": f"Bearer {api_key}", + } + inputs = ["Hello", "World"] + + mock_workspace_client: WorkspaceClient = MagicMock() + mock_config: Config = MagicMock() + # Simulate the behavior of the config property and its methods + mock_config.authenticate.side_effect = lambda: headers + mock_config.host = base_url # Assign directly as if it's a property + mock_workspace_client.config = mock_config + + with patch( + "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client + ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post: + response = litellm.embedding( + model="databricks/bge-large-en-v1.5", + input=inputs, + client=sync_handler, + extraparam="testpassingextraparam", + ) + assert response.to_dict() == mock_embedding_response() + + mock_post.assert_called_once_with( + f"{base_url}/serving-endpoints/embeddings", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "bge-large-en-v1.5", + "input": inputs, + "extraparam": "testpassingextraparam", + } + ), + ) diff --git a/tests/llm_translation/test_fireworks_ai_translation.py b/tests/llm_translation/test_fireworks_ai_translation.py index c7c1f5445..00361cd18 100644 --- a/tests/llm_translation/test_fireworks_ai_translation.py +++ b/tests/llm_translation/test_fireworks_ai_translation.py @@ -7,7 +7,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from litellm.llms.fireworks_ai import FireworksAIConfig +from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig fireworks = FireworksAIConfig() diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py index 53f9fc761..6d5eb8e3c 100644 --- a/tests/llm_translation/test_max_completion_tokens.py +++ b/tests/llm_translation/test_max_completion_tokens.py @@ -149,7 +149,9 @@ def test_all_model_configs(): {"max_completion_tokens": 10}, {}, "llama3" ) == {"max_tokens": 10} - from litellm.llms.fireworks_ai import FireworksAIConfig + from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import ( + FireworksAIConfig, + ) assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params() assert FireworksAIConfig().map_openai_params(