diff --git a/docker-compose.yml b/docker-compose.yml
index 6991bf7eb..4ae8b717d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -9,6 +9,14 @@ services:
#########################################
## Uncomment these lines to start proxy with a config.yaml file ##
# volumes:
+ # - ./config.yaml:/app/config.yaml <<- this is missing in the docker-compose file currently
+ # The below two are my suggestion
+ # command:
+ # - "--config=/app/config.yaml"
+ ##############################################
+ #########################################
+ ## Uncomment these lines to start proxy with a config.yaml file ##
+ # volumes:
###############################################
ports:
- "4000:4000" # Map the container port to the host, change the host port if necessary
diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md
index 0c7c2cd92..c563a5bf0 100644
--- a/docs/my-website/docs/completion/input.md
+++ b/docs/my-website/docs/completion/input.md
@@ -124,16 +124,19 @@ def completion(
#### Properties of `messages`
*Note* - Each message in the array contains the following properties:
-- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, or function.
+- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, function or tool.
-- `content`: *string or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
+- `content`: *string or list[dict] or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
- `name`: *string (optional)* - The name of the author of the message. It is required if the role is "function". The name should match the name of the function represented in the content. It can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters.
- `function_call`: *object (optional)* - The name and arguments of a function that should be called, as generated by the model.
+- `tool_call_id`: *str (optional)* - Tool call that this message is responding to.
+[**See All Message Values**](https://github.com/BerriAI/litellm/blob/8600ec77042dacad324d3879a2bd918fc6a719fa/litellm/types/llms/openai.py#L392)
+
## Optional Fields
- `temperature`: *number or null (optional)* - The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic.
diff --git a/docs/my-website/docs/proxy/caching.md b/docs/my-website/docs/proxy/caching.md
index 220ef5c36..4d44a4da0 100644
--- a/docs/my-website/docs/proxy/caching.md
+++ b/docs/my-website/docs/proxy/caching.md
@@ -110,6 +110,60 @@ print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
+#### Redis Sentinel
+
+
+
+
+
+
+```yaml
+model_list:
+ - model_name: "*"
+ litellm_params:
+ model: "*"
+
+
+litellm_settings:
+ cache: true
+ cache_params:
+ type: "redis"
+ service_name: "mymaster"
+ sentinel_nodes: [["localhost", 26379]]
+```
+
+
+
+
+
+You can configure redis sentinel in your .env by setting `REDIS_SENTINEL_NODES` in your .env
+
+**Example `REDIS_SENTINEL_NODES`** value
+
+```env
+REDIS_SENTINEL_NODES='[["localhost", 26379]]'
+REDIS_SERVICE_NAME = "mymaster"
+```
+
+:::note
+
+Example python script for setting redis cluster nodes in .env:
+
+```python
+# List of startup nodes
+sentinel_nodes = [["localhost", 26379]]
+
+# set startup nodes in environment variables
+os.environ["REDIS_SENTINEL_NODES"] = json.dumps(sentinel_nodes)
+print("REDIS_SENTINEL_NODES", os.environ["REDIS_SENTINEL_NODES"])
+```
+
+:::
+
+
+
+
+
#### TTL
```yaml
diff --git a/litellm/__init__.py b/litellm/__init__.py
index db264b016..fb6234e04 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -382,73 +382,81 @@ deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
gemini_models: List = []
-for key, value in model_cost.items():
- if value.get("litellm_provider") == "openai":
- open_ai_chat_completion_models.append(key)
- elif value.get("litellm_provider") == "text-completion-openai":
- open_ai_text_completion_models.append(key)
- elif value.get("litellm_provider") == "cohere":
- cohere_models.append(key)
- elif value.get("litellm_provider") == "cohere_chat":
- cohere_chat_models.append(key)
- elif value.get("litellm_provider") == "mistral":
- mistral_chat_models.append(key)
- elif value.get("litellm_provider") == "anthropic":
- anthropic_models.append(key)
- elif value.get("litellm_provider") == "empower":
- empower_models.append(key)
- elif value.get("litellm_provider") == "openrouter":
- openrouter_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-text-models":
- vertex_text_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-code-text-models":
- vertex_code_text_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-language-models":
- vertex_language_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-vision-models":
- vertex_vision_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-chat-models":
- vertex_chat_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
- vertex_code_chat_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-embedding-models":
- vertex_embedding_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
- key = key.replace("vertex_ai/", "")
- vertex_anthropic_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-llama_models":
- key = key.replace("vertex_ai/", "")
- vertex_llama3_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-mistral_models":
- key = key.replace("vertex_ai/", "")
- vertex_mistral_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-ai21_models":
- key = key.replace("vertex_ai/", "")
- vertex_ai_ai21_models.append(key)
- elif value.get("litellm_provider") == "vertex_ai-image-models":
- key = key.replace("vertex_ai/", "")
- vertex_ai_image_models.append(key)
- elif value.get("litellm_provider") == "ai21":
- if value.get("mode") == "chat":
- ai21_chat_models.append(key)
- else:
- ai21_models.append(key)
- elif value.get("litellm_provider") == "nlp_cloud":
- nlp_cloud_models.append(key)
- elif value.get("litellm_provider") == "aleph_alpha":
- aleph_alpha_models.append(key)
- elif value.get("litellm_provider") == "bedrock":
- bedrock_models.append(key)
- elif value.get("litellm_provider") == "deepinfra":
- deepinfra_models.append(key)
- elif value.get("litellm_provider") == "perplexity":
- perplexity_models.append(key)
- elif value.get("litellm_provider") == "watsonx":
- watsonx_models.append(key)
- elif value.get("litellm_provider") == "gemini":
- gemini_models.append(key)
- elif value.get("litellm_provider") == "fireworks_ai":
- fireworks_ai_models.append(key)
+
+
+def add_known_models():
+ for key, value in model_cost.items():
+ if value.get("litellm_provider") == "openai":
+ open_ai_chat_completion_models.append(key)
+ elif value.get("litellm_provider") == "text-completion-openai":
+ open_ai_text_completion_models.append(key)
+ elif value.get("litellm_provider") == "cohere":
+ cohere_models.append(key)
+ elif value.get("litellm_provider") == "cohere_chat":
+ cohere_chat_models.append(key)
+ elif value.get("litellm_provider") == "mistral":
+ mistral_chat_models.append(key)
+ elif value.get("litellm_provider") == "anthropic":
+ anthropic_models.append(key)
+ elif value.get("litellm_provider") == "empower":
+ empower_models.append(key)
+ elif value.get("litellm_provider") == "openrouter":
+ openrouter_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-text-models":
+ vertex_text_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-code-text-models":
+ vertex_code_text_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-language-models":
+ vertex_language_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-vision-models":
+ vertex_vision_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-chat-models":
+ vertex_chat_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
+ vertex_code_chat_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-embedding-models":
+ vertex_embedding_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
+ key = key.replace("vertex_ai/", "")
+ vertex_anthropic_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-llama_models":
+ key = key.replace("vertex_ai/", "")
+ vertex_llama3_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-mistral_models":
+ key = key.replace("vertex_ai/", "")
+ vertex_mistral_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-ai21_models":
+ key = key.replace("vertex_ai/", "")
+ vertex_ai_ai21_models.append(key)
+ elif value.get("litellm_provider") == "vertex_ai-image-models":
+ key = key.replace("vertex_ai/", "")
+ vertex_ai_image_models.append(key)
+ elif value.get("litellm_provider") == "ai21":
+ if value.get("mode") == "chat":
+ ai21_chat_models.append(key)
+ else:
+ ai21_models.append(key)
+ elif value.get("litellm_provider") == "nlp_cloud":
+ nlp_cloud_models.append(key)
+ elif value.get("litellm_provider") == "aleph_alpha":
+ aleph_alpha_models.append(key)
+ elif value.get("litellm_provider") == "bedrock":
+ bedrock_models.append(key)
+ elif value.get("litellm_provider") == "deepinfra":
+ deepinfra_models.append(key)
+ elif value.get("litellm_provider") == "perplexity":
+ perplexity_models.append(key)
+ elif value.get("litellm_provider") == "watsonx":
+ watsonx_models.append(key)
+ elif value.get("litellm_provider") == "gemini":
+ gemini_models.append(key)
+ elif value.get("litellm_provider") == "fireworks_ai":
+ # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
+ if "-to-" not in key:
+ fireworks_ai_models.append(key)
+
+
+add_known_models()
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
"api.perplexity.ai",
@@ -960,7 +968,7 @@ from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
from .llms.AI21.chat import AI21ChatConfig
-from .llms.fireworks_ai import FireworksAIConfig
+from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.AzureOpenAI.azure import (
diff --git a/litellm/_redis.py b/litellm/_redis.py
index faa98e648..152f7f09e 100644
--- a/litellm/_redis.py
+++ b/litellm/_redis.py
@@ -12,12 +12,13 @@ import json
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
-from typing import List, Optional
+from typing import List, Optional, Union
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
import litellm
+from litellm import get_secret
from ._logging import verbose_logger
@@ -83,7 +84,7 @@ def _redis_kwargs_from_environment():
return_dict = {}
for k, v in mapping.items():
- value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
+ value = get_secret(k, default_value=None) # type: ignore
if value is not None:
return_dict[v] = value
return return_dict
@@ -116,7 +117,7 @@ def _get_redis_client_logic(**env_overrides):
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
v = v.replace("os.environ/", "")
- value = litellm.get_secret(v)
+ value = get_secret(v) # type: ignore
env_overrides[k] = value
redis_kwargs = {
@@ -124,13 +125,27 @@ def _get_redis_client_logic(**env_overrides):
**env_overrides,
}
- _startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
+ _startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
"REDIS_CLUSTER_NODES"
)
- if _startup_nodes is not None:
+ if _startup_nodes is not None and isinstance(_startup_nodes, str):
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
+ _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
+ "REDIS_SENTINEL_NODES"
+ )
+
+ if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
+ redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
+
+ _service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
+ "REDIS_SERVICE_NAME"
+ )
+
+ if _service_name is not None:
+ redis_kwargs["service_name"] = _service_name
+
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
@@ -138,14 +153,19 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
+ elif (
+ "sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
+ ):
+ pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
+
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
- _redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
+ _redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
@@ -174,6 +194,44 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
+def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
+ sentinel_nodes = redis_kwargs.get("sentinel_nodes")
+ service_name = redis_kwargs.get("service_name")
+
+ if not sentinel_nodes or not service_name:
+ raise ValueError(
+ "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
+ )
+
+ verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
+
+ # Set up the Sentinel client
+ sentinel = redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
+
+ # Return the master instance for the given service
+
+ return sentinel.master_for(service_name)
+
+
+def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
+ sentinel_nodes = redis_kwargs.get("sentinel_nodes")
+ service_name = redis_kwargs.get("service_name")
+
+ if not sentinel_nodes or not service_name:
+ raise ValueError(
+ "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
+ )
+
+ verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
+
+ # Set up the Sentinel client
+ sentinel = async_redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
+
+ # Return the master instance for the given service
+
+ return sentinel.master_for(service_name)
+
+
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
@@ -185,12 +243,13 @@ def get_redis_client(**env_overrides):
return redis.Redis.from_url(**url_kwargs)
- if (
- "startup_nodes" in redis_kwargs
- or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
- ):
+ if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
return init_redis_cluster(redis_kwargs)
+ # Check for Redis Sentinel
+ if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
+ return _init_redis_sentinel(redis_kwargs)
+
return redis.Redis(**redis_kwargs)
@@ -203,7 +262,7 @@ def get_redis_async_client(**env_overrides):
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
- litellm.print_verbose(
+ verbose_logger.debug(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
@@ -225,9 +284,13 @@ def get_redis_async_client(**env_overrides):
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
- startup_nodes=new_startup_nodes, **cluster_kwargs
+ startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
)
+ # Check for Redis Sentinel
+ if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
+ return _init_async_redis_sentinel(redis_kwargs)
+
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 0a935a290..a176190d0 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -25,6 +25,9 @@ from litellm.llms.anthropic.cost_calculation import (
from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token,
)
+from litellm.llms.fireworks_ai.cost_calculator import (
+ cost_per_token as fireworks_ai_cost_per_token,
+)
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@@ -217,6 +220,8 @@ def cost_per_token(
return anthropic_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "databricks":
return databricks_cost_per_token(model=model, usage=usage_block)
+ elif custom_llm_provider == "fireworks_ai":
+ return fireworks_ai_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py
index be3e2cbee..aee070c58 100644
--- a/litellm/llms/AzureOpenAI/azure.py
+++ b/litellm/llms/AzureOpenAI/azure.py
@@ -13,18 +13,14 @@ from pydantic import BaseModel
from typing_extensions import overload
import litellm
-from litellm import ImageResponse, OpenAIConfig
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
- Choices,
CustomStreamWrapper,
- Message,
ModelResponse,
- TranscriptionResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
@@ -674,7 +670,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj=logging_obj,
convert_tool_call_to_json_mode=json_mode,
)
- elif "stream" in optional_params and optional_params["stream"] == True:
+ elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
@@ -725,7 +721,11 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
- if client is None or dynamic_params:
+ if (
+ client is None
+ or not isinstance(client, AzureOpenAI)
+ or dynamic_params
+ ):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
@@ -824,7 +824,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
- "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -930,7 +933,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
- "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -988,7 +994,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
- "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -1567,12 +1576,11 @@ class AzureChatCompletion(BaseLLM):
# return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
- exception_mapping_worked = True
raise e
except Exception as e:
- if hasattr(e, "status_code"):
- _status_code = getattr(e, "status_code")
- raise AzureOpenAIError(status_code=_status_code, message=str(e))
+ error_code = getattr(e, "status_code", None)
+ if error_code is not None:
+ raise AzureOpenAIError(status_code=error_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index e63bd3f54..58fdf065e 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -4,6 +4,7 @@ import traceback
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
import httpx
+from httpx import USE_CLIENT_DEFAULT
import litellm
@@ -76,9 +77,20 @@ class AsyncHTTPHandler:
await self.client.aclose()
async def get(
- self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
+ self,
+ url: str,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ follow_redirects: Optional[bool] = None,
):
- response = await self.client.get(url, params=params, headers=headers)
+ # Set follow_redirects to UseClientDefault if None
+ _follow_redirects = (
+ follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
+ )
+
+ response = await self.client.get(
+ url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
+ )
return response
async def post(
@@ -117,8 +129,9 @@ class AsyncHTTPHandler:
await new_client.aclose()
except httpx.TimeoutException as e:
headers = {}
- if hasattr(e, "response") and e.response is not None:
- for key, value in e.response.headers.items():
+ error_response = getattr(e, "response", None)
+ if error_response is not None:
+ for key, value in error_response.headers.items():
headers["response_headers-{}".format(key)] = value
raise litellm.Timeout(
@@ -173,8 +186,9 @@ class AsyncHTTPHandler:
await new_client.aclose()
except httpx.TimeoutException as e:
headers = {}
- if hasattr(e, "response") and e.response is not None:
- for key, value in e.response.headers.items():
+ error_response = getattr(e, "response", None)
+ if error_response is not None:
+ for key, value in error_response.headers.items():
headers["response_headers-{}".format(key)] = value
raise litellm.Timeout(
@@ -303,9 +317,20 @@ class HTTPHandler:
self.client.close()
def get(
- self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
+ self,
+ url: str,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ follow_redirects: Optional[bool] = None,
):
- response = self.client.get(url, params=params, headers=headers)
+ # Set follow_redirects to UseClientDefault if None
+ _follow_redirects = (
+ follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
+ )
+
+ response = self.client.get(
+ url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
+ )
return response
def post(
diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py
index 343cdd3ff..23c780746 100644
--- a/litellm/llms/databricks/chat.py
+++ b/litellm/llms/databricks/chat.py
@@ -244,6 +244,34 @@ class DatabricksChatCompletion(BaseLLM):
# makes headers for API call
+ def _get_databricks_credentials(
+ self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
+ ) -> Tuple[str, dict]:
+ headers = headers or {"Content-Type": "application/json"}
+ try:
+ from databricks.sdk import WorkspaceClient
+
+ databricks_client = WorkspaceClient()
+ api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
+
+ if api_key is None:
+ databricks_auth_headers: dict[str, str] = (
+ databricks_client.config.authenticate()
+ )
+ headers = {**databricks_auth_headers, **headers}
+
+ return api_base, headers
+ except ImportError:
+ raise DatabricksError(
+ status_code=400,
+ message=(
+ "If the Databricks base URL and API key are not set, the databricks-sdk "
+ "Python library must be installed. Please install the databricks-sdk, set "
+ "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
+ "or provide the base URL and API key as arguments."
+ ),
+ )
+
def _validate_environment(
self,
api_key: Optional[str],
@@ -253,16 +281,26 @@ class DatabricksChatCompletion(BaseLLM):
headers: Optional[dict],
) -> Tuple[str, dict]:
if api_key is None and headers is None:
- raise DatabricksError(
- status_code=400,
- message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
- )
+ if custom_endpoint:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
+ )
+ else:
+ api_base, headers = self._get_databricks_credentials(
+ api_base=api_base, api_key=api_key, headers=headers
+ )
if api_base is None:
- raise DatabricksError(
- status_code=400,
- message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
- )
+ if custom_endpoint:
+ raise DatabricksError(
+ status_code=400,
+ message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
+ )
+ else:
+ api_base, headers = self._get_databricks_credentials(
+ api_base=api_base, api_key=api_key, headers=headers
+ )
if headers is None:
headers = {
@@ -273,6 +311,9 @@ class DatabricksChatCompletion(BaseLLM):
if api_key is not None:
headers.update({"Authorization": "Bearer {}".format(api_key)})
+ if api_key is not None:
+ headers["Authorization"] = f"Bearer {api_key}"
+
if endpoint_type == "chat_completions" and custom_endpoint is not True:
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings" and custom_endpoint is not True:
@@ -520,7 +561,8 @@ class DatabricksChatCompletion(BaseLLM):
response_json = response.json()
except httpx.HTTPStatusError as e:
raise DatabricksError(
- status_code=e.response.status_code, message=e.response.text
+ status_code=e.response.status_code,
+ message=e.response.text,
)
except httpx.TimeoutException as e:
raise DatabricksError(
diff --git a/litellm/llms/fireworks_ai.py b/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py
similarity index 99%
rename from litellm/llms/fireworks_ai.py
rename to litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py
index b6511689e..661c8352f 100644
--- a/litellm/llms/fireworks_ai.py
+++ b/litellm/llms/fireworks_ai/chat/fireworks_ai_transformation.py
@@ -1,8 +1,6 @@
import types
from typing import Literal, Optional, Union
-import litellm
-
class FireworksAIConfig:
"""
diff --git a/litellm/llms/fireworks_ai/cost_calculator.py b/litellm/llms/fireworks_ai/cost_calculator.py
new file mode 100644
index 000000000..83ce97d4c
--- /dev/null
+++ b/litellm/llms/fireworks_ai/cost_calculator.py
@@ -0,0 +1,72 @@
+"""
+For calculating cost of fireworks ai serverless inference models.
+"""
+
+from typing import Tuple
+
+from litellm.types.utils import Usage
+from litellm.utils import get_model_info
+
+
+# Extract the number of billion parameters from the model name
+# only used for together_computer LLMs
+def get_model_params_and_category(model_name: str) -> str:
+ """
+ Helper function for calculating together ai pricing.
+
+ Returns:
+ - str: model pricing category if mapped else received model name
+ """
+ import re
+
+ model_name = model_name.lower()
+
+ # Check for MoE models in the form xb
+ moe_match = re.search(r"(\d+)x(\d+)b", model_name)
+ if moe_match:
+ total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
+ if total_billion <= 56:
+ return "fireworks-ai-moe-up-to-56b"
+ elif total_billion <= 176:
+ return "fireworks-ai-56b-to-176b"
+
+ # Check for standard models in the form b
+ re_params_match = re.search(r"(\d+)b", model_name)
+ if re_params_match is not None:
+ params_match = str(re_params_match.group(1))
+ params_billion = float(params_match)
+
+ # Determine the category based on the number of parameters
+ if params_billion <= 16.0:
+ return "fireworks-ai-up-to-16b"
+ elif params_billion <= 80.0:
+ return "fireworks-ai-16b-80b"
+
+ # If no matches, return the original model_name
+ return model_name
+
+
+def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
+ """
+ Calculates the cost per token for a given model, prompt tokens, and completion tokens.
+
+ Input:
+ - model: str, the model name without provider prefix
+ - usage: LiteLLM Usage block, containing anthropic caching information
+
+ Returns:
+ Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+ """
+ base_model = get_model_params_and_category(model_name=model)
+
+ ## GET MODEL INFO
+ model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
+
+ ## CALCULATE INPUT COST
+
+ prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
+
+ ## CALCULATE OUTPUT COST
+ completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
+
+ return prompt_cost, completion_cost
diff --git a/litellm/llms/prompt_templates/image_handling.py b/litellm/llms/prompt_templates/image_handling.py
index 90db3dedc..c91c05df9 100644
--- a/litellm/llms/prompt_templates/image_handling.py
+++ b/litellm/llms/prompt_templates/image_handling.py
@@ -7,6 +7,7 @@ import base64
from httpx import Response
import litellm
+from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
@@ -58,7 +59,7 @@ async def async_convert_url_to_base64(url: str) -> str:
client = litellm.module_level_aclient
for _ in range(3):
try:
- response = await client.get(url)
+ response = await client.get(url, follow_redirects=True)
return _process_image_response(response, url)
except:
pass
@@ -75,9 +76,11 @@ def convert_url_to_base64(url: str) -> str:
client = litellm.module_level_client
for _ in range(3):
try:
- response = client.get(url)
+ response = client.get(url, follow_redirects=True)
return _process_image_response(response, url)
- except:
+ except Exception as e:
+ verbose_logger.exception(e)
+ # print(e)
pass
raise Exception(
f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index dbd9bd73b..23389d530 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -5452,6 +5452,26 @@
"mode": "chat",
"supports_function_calling": true,
"source": "https://fireworks.ai/pricing"
+ },
+ "fireworks-ai-up-to-16b": {
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-16.1b-to-80b": {
+ "input_cost_per_token": 0.0000009,
+ "output_cost_per_token": 0.0000009,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-moe-up-to-56b": {
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000005,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-56b-to-176b": {
+ "input_cost_per_token": 0.0000012,
+ "output_cost_per_token": 0.0000012,
+ "litellm_provider": "fireworks_ai"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index a2911dca6..5773f9f51 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -19,4 +19,11 @@ model_list:
- model_name: o1-preview
litellm_params:
model: o1-preview
-
\ No newline at end of file
+
+litellm_settings:
+ cache: true
+ # cache_params:
+ # type: "redis"
+ # service_name: "mymaster"
+ # sentinel_nodes:
+ # - ["localhost", 26379]
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 20dab118b..613dacb57 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -653,38 +653,19 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
return
try:
- from azure.identity import ClientSecretCredential, DefaultAzureCredential
+ from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
# Set your Azure Key Vault URI
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
- # Set your Azure AD application/client ID, client secret, and tenant ID
- client_id = os.getenv("AZURE_CLIENT_ID", None)
- client_secret = os.getenv("AZURE_CLIENT_SECRET", None)
- tenant_id = os.getenv("AZURE_TENANT_ID", None)
+ credential = DefaultAzureCredential()
- if (
- KVUri is not None
- and client_id is not None
- and client_secret is not None
- and tenant_id is not None
- ):
- # Initialize the ClientSecretCredential
- # credential = ClientSecretCredential(
- # client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
- # )
- credential = DefaultAzureCredential()
+ # Create the SecretClient using the credential
+ client = SecretClient(vault_url=KVUri, credential=credential)
- # Create the SecretClient using the credential
- client = SecretClient(vault_url=KVUri, credential=credential)
-
- litellm.secret_manager_client = client
- litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
- else:
- raise Exception(
- f"Missing KVUri or client_id or client_secret or tenant_id from environment"
- )
+ litellm.secret_manager_client = client
+ litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
except Exception as e:
_error_str = str(e)
verbose_proxy_logger.exception(
@@ -1626,8 +1607,8 @@ class ProxyConfig:
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables
self._init_cache(cache_params=cache_params)
if litellm.cache is not None:
- verbose_proxy_logger.debug( # noqa
- f"{blue_color_code}Set Cache on LiteLLM Proxy= {vars(litellm.cache.cache)}{vars(litellm.cache)}{reset_color_code}"
+ verbose_proxy_logger.debug(
+ f"{blue_color_code}Set Cache on LiteLLM Proxy{reset_color_code}"
)
elif key == "cache" and value is False:
pass
diff --git a/litellm/router.py b/litellm/router.py
index 2a3f583fa..780eeb3e7 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4019,7 +4019,9 @@ class Router:
_model_info=_model_info,
)
- verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
+ verbose_router_logger.debug(
+ f"\nInitialized Model List {self.get_model_names()}"
+ )
self.model_names = [m["model_name"] for m in model_list]
def _add_deployment(self, deployment: Deployment) -> Deployment:
@@ -4630,24 +4632,25 @@ class Router:
if hasattr(self, "model_list"):
returned_models: List[DeploymentTypedDict] = []
- for model_alias, model_value in self.model_group_alias.items():
+ if hasattr(self, "model_group_alias"):
+ for model_alias, model_value in self.model_group_alias.items():
- if isinstance(model_value, str):
- _router_model_name: str = model_value
- elif isinstance(model_value, dict):
- _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
- if _model_value["hidden"] is True:
- continue
+ if isinstance(model_value, str):
+ _router_model_name: str = model_value
+ elif isinstance(model_value, dict):
+ _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
+ if _model_value["hidden"] is True:
+ continue
+ else:
+ _router_model_name = _model_value["model"]
else:
- _router_model_name = _model_value["model"]
- else:
- continue
+ continue
- returned_models.extend(
- self._get_all_deployments(
- model_name=_router_model_name, model_alias=model_alias
+ returned_models.extend(
+ self._get_all_deployments(
+ model_name=_router_model_name, model_alias=model_alias
+ )
)
- )
if model_name is None:
returned_models += self.model_list
@@ -5030,7 +5033,7 @@ class Router:
# return the first deployment where the `model` matches the specificed deployment name
return deployment_model, deployment
raise ValueError(
- f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
+ f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
)
elif model in self.get_model_ids():
deployment = self.get_model_info(id=model)
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index a52f17847..5244fa4b7 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -2045,3 +2045,57 @@ async def test_proxy_logging_setup():
pl_obj = ProxyLogging(user_api_key_cache=DualCache())
assert pl_obj.internal_usage_cache.always_read_redis is True
+
+
+@pytest.mark.skip(reason="local test. Requires sentinel setup.")
+@pytest.mark.asyncio
+async def test_redis_sentinel_caching():
+ """
+ Init redis client
+ - write to client
+ - read from client
+ """
+ litellm.set_verbose = False
+
+ random_number = random.randint(
+ 1, 100000
+ ) # add a random number to ensure it's always adding / reading from cache
+ messages = [
+ {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+ ]
+
+ litellm.cache = Cache(
+ type="redis",
+ # host=os.environ["REDIS_HOST"],
+ # port=os.environ["REDIS_PORT"],
+ # password=os.environ["REDIS_PASSWORD"],
+ service_name="mymaster",
+ sentinel_nodes=[("localhost", 26379)],
+ )
+ response1 = completion(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ )
+
+ cache_key = litellm.cache.get_cache_key(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ )
+ print(f"cache_key: {cache_key}")
+ litellm.cache.add_cache(result=response1, cache_key=cache_key)
+ print(f"cache key pre async get: {cache_key}")
+ stored_val = litellm.cache.get_cache(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ )
+
+ print(f"stored_val: {stored_val}")
+ assert stored_val["id"] == response1.id
+
+ stored_val_2 = await litellm.cache.async_get_cache(
+ model="gpt-3.5-turbo",
+ messages=messages,
+ )
+
+ print(f"stored_val: {stored_val}")
+ assert stored_val_2["id"] == response1.id
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 835cced45..82257fad8 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -1255,3 +1255,16 @@ def test_completion_cost_databricks_embedding(model):
print(resp)
cost = completion_cost(completion_response=resp)
+
+
+def test_completion_cost_fireworks_ai():
+ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+ litellm.model_cost = litellm.get_model_cost_map(url="")
+
+ messages = [{"role": "user", "content": "Hey, how's it going?"}]
+ resp = litellm.completion(
+ model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
+ ) # works fine
+
+ print(resp)
+ cost = completion_cost(completion_response=resp)
diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py
index 4c99efb3e..5f285e00a 100644
--- a/litellm/tests/test_prompt_factory.py
+++ b/litellm/tests/test_prompt_factory.py
@@ -432,3 +432,7 @@ def test_vertex_only_image_user_message():
), "Invalid gemini input. Got={}, Expected={}".format(
content, expected_response[idx]
)
+
+
+def test_convert_url():
+ convert_url_to_base64("https://picsum.photos/id/237/200/300")
diff --git a/litellm/utils.py b/litellm/utils.py
index 7c5ef4248..7d6d5223c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6830,7 +6830,10 @@ def exception_type(
llm_provider=custom_llm_provider,
model=model,
)
- elif original_exception.status_code == 401:
+ elif (
+ original_exception.status_code == 401
+ or original_exception.status_code == 403
+ ):
exception_mapping_worked = True
raise AuthenticationError(
message=f"{custom_llm_provider}Exception - {original_exception.message}",
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index dbd9bd73b..23389d530 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -5452,6 +5452,26 @@
"mode": "chat",
"supports_function_calling": true,
"source": "https://fireworks.ai/pricing"
+ },
+ "fireworks-ai-up-to-16b": {
+ "input_cost_per_token": 0.0000002,
+ "output_cost_per_token": 0.0000002,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-16.1b-to-80b": {
+ "input_cost_per_token": 0.0000009,
+ "output_cost_per_token": 0.0000009,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-moe-up-to-56b": {
+ "input_cost_per_token": 0.0000005,
+ "output_cost_per_token": 0.0000005,
+ "litellm_provider": "fireworks_ai"
+ },
+ "fireworks-ai-56b-to-176b": {
+ "input_cost_per_token": 0.0000012,
+ "output_cost_per_token": 0.0000012,
+ "litellm_provider": "fireworks_ai"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,
diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py
index 067b188ed..b3bd92d8d 100644
--- a/tests/llm_translation/test_databricks.py
+++ b/tests/llm_translation/test_databricks.py
@@ -7,10 +7,17 @@ from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import litellm
-from litellm.exceptions import BadRequestError, InternalServerError
+from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
+try:
+ import databricks.sdk
+
+ databricks_sdk_installed = True
+except ImportError:
+ databricks_sdk_installed = False
+
def mock_chat_response() -> Dict[str, Any]:
return {
@@ -33,8 +40,8 @@ def mock_chat_response() -> Dict[str, Any]:
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
- "total_tokens": 268,
"completion_tokens_details": None,
+ "total_tokens": 268,
},
"system_fingerprint": None,
}
@@ -195,7 +202,14 @@ def mock_embedding_response() -> Dict[str, Any]:
@pytest.mark.parametrize("set_base", [True, False])
-def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
+def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
+ monkeypatch, set_base
+):
+ # Simulate that the databricks SDK is not installed
+ monkeypatch.setitem(sys.modules, "databricks.sdk", None)
+
+ err_msg = "the Databricks base URL and API key are not set"
+
if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
@@ -204,11 +218,11 @@ def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
- err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
- monkeypatch.delenv("DATABRICKS_API_BASE")
- err_msg = "A call is being made to LLM Provider but no api base is set"
+ monkeypatch.delenv(
+ "DATABRICKS_API_BASE",
+ )
with pytest.raises(BadRequestError) as exc:
litellm.completion(
@@ -422,6 +436,67 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
)
+@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
+def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.config import Config
+
+ sync_handler = HTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_chat_response()
+
+ expected_response_json = {
+ **mock_chat_response(),
+ **{
+ "model": "databricks/dbrx-instruct-071224",
+ },
+ }
+
+ base_url = "https://my.workspace.cloud.databricks.com"
+ api_key = "dapimykey"
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ }
+ messages = [{"role": "user", "content": "How are you?"}]
+
+ mock_workspace_client: WorkspaceClient = MagicMock()
+ mock_config: Config = MagicMock()
+ # Simulate the behavior of the config property and its methods
+ mock_config.authenticate.side_effect = lambda: headers
+ mock_config.host = base_url # Assign directly as if it's a property
+ mock_workspace_client.config = mock_config
+
+ with patch(
+ "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
+ ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
+ response = litellm.completion(
+ model="databricks/dbrx-instruct-071224",
+ messages=messages,
+ client=sync_handler,
+ temperature=0.5,
+ extraparam="testpassingextraparam",
+ )
+ assert response.to_dict() == expected_response_json
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/serving-endpoints/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "dbrx-instruct-071224",
+ "messages": messages,
+ "temperature": 0.5,
+ "extraparam": "testpassingextraparam",
+ "stream": False,
+ }
+ ),
+ )
+
+
def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
@@ -500,3 +575,59 @@ def test_embeddings_with_async_http_handler(monkeypatch):
}
),
)
+
+
+@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
+def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
+ from databricks.sdk import WorkspaceClient
+ from databricks.sdk.config import Config
+
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ sync_handler = HTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_embedding_response()
+
+ base_url = "https://my.workspace.cloud.databricks.com"
+ api_key = "dapimykey"
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ }
+ inputs = ["Hello", "World"]
+
+ mock_workspace_client: WorkspaceClient = MagicMock()
+ mock_config: Config = MagicMock()
+ # Simulate the behavior of the config property and its methods
+ mock_config.authenticate.side_effect = lambda: headers
+ mock_config.host = base_url # Assign directly as if it's a property
+ mock_workspace_client.config = mock_config
+
+ with patch(
+ "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
+ ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
+ response = litellm.embedding(
+ model="databricks/bge-large-en-v1.5",
+ input=inputs,
+ client=sync_handler,
+ extraparam="testpassingextraparam",
+ )
+ assert response.to_dict() == mock_embedding_response()
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/serving-endpoints/embeddings",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "bge-large-en-v1.5",
+ "input": inputs,
+ "extraparam": "testpassingextraparam",
+ }
+ ),
+ )
diff --git a/tests/llm_translation/test_fireworks_ai_translation.py b/tests/llm_translation/test_fireworks_ai_translation.py
index c7c1f5445..00361cd18 100644
--- a/tests/llm_translation/test_fireworks_ai_translation.py
+++ b/tests/llm_translation/test_fireworks_ai_translation.py
@@ -7,7 +7,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-from litellm.llms.fireworks_ai import FireworksAIConfig
+from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
fireworks = FireworksAIConfig()
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index 53f9fc761..6d5eb8e3c 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -149,7 +149,9 @@ def test_all_model_configs():
{"max_completion_tokens": 10}, {}, "llama3"
) == {"max_tokens": 10}
- from litellm.llms.fireworks_ai import FireworksAIConfig
+ from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
+ FireworksAIConfig,
+ )
assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
assert FireworksAIConfig().map_openai_params(