LiteLLM Minor Fixes & Improvements (09/18/2024) (#5772)

* fix(proxy_server.py): fix azure key vault logic to not require client id/secret
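A minimal sketch of what this now requires, assuming the proxy's `use_azure_key_vault` setting is enabled and credential resolution is left to `DefaultAzureCredential`:

```python
# Sketch: only the vault URI is needed; AZURE_CLIENT_ID / AZURE_CLIENT_SECRET /
# AZURE_TENANT_ID are no longer mandatory. DefaultAzureCredential falls back to
# managed identity, `az login`, environment credentials, etc.
import os

os.environ["AZURE_KEY_VAULT_URI"] = "https://my-vault.vault.azure.net/"  # illustrative URI
```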

* feat(cost_calculator.py): support fireworks ai cost tracking
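A rough usage sketch (illustrative model name; pricing comes from the new `fireworks-ai-*` entries added to the cost map, keyed by parameter-count bucket):

```python
# Sketch: cost tracking for a Fireworks AI serverless model.
# Assumes FIREWORKS_AI_API_KEY is set in the environment.
import litellm
from litellm import completion_cost

resp = litellm.completion(
    model="fireworks_ai/mixtral-8x7b-instruct",  # illustrative model name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print("cost (USD):", completion_cost(completion_response=resp))
```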

* build(docker-compose.yml): add lines for mounting config.yaml to docker compose

Closes https://github.com/BerriAI/litellm/issues/5739

* fix(input.md): update docs to clarify litellm supports content as a list of dictionaries

Fixes https://github.com/BerriAI/litellm/issues/5755
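For example (illustrative model and image URL), `content` can be a list of typed blocks instead of a plain string:

```python
# Sketch: message `content` passed as a list of dictionaries (OpenAI-style content blocks).
# Assumes OPENAI_API_KEY is set in the environment.
import litellm

resp = litellm.completion(
    model="gpt-4o-mini",  # illustrative model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
)
```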

* fix(input.md): update input.md to include all message values

* fix(image_handling.py): follow image url redirects

Fixes https://github.com/BerriAI/litellm/issues/5763

* fix(router.py): Fix model key/base leak in error message

Fixes https://github.com/BerriAI/litellm/issues/5762

* fix(http_handler.py): fix linting error

* fix(azure.py): fix logging to show azure_ad_token being used

Fixes https://github.com/BerriAI/litellm/issues/5767

* fix(_redis.py): add redis sentinel support

Closes https://github.com/BerriAI/litellm/issues/4381

* feat(_redis.py): add redis sentinel support

Closes https://github.com/BerriAI/litellm/issues/4381
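A minimal caching sketch against a Sentinel deployment, mirroring the docs and test added below (service name and host/port are illustrative):

```python
# Sketch: point the litellm cache at Redis Sentinel instead of a single Redis node.
import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis",
    service_name="mymaster",                # Sentinel service (master) name
    sentinel_nodes=[("localhost", 26379)],  # (host, port) pairs of sentinel endpoints
)
```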

* test(test_completion_cost.py): fix test

* Databricks Integration: Integrate Databricks SDK as optional mechanism for fetching API base and token, if unspecified (#5746)
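A hedged sketch of the new fallback: with `DATABRICKS_API_BASE`/`DATABRICKS_API_KEY` unset, the Databricks SDK's `WorkspaceClient` supplies the workspace host and auth headers (requires `pip install databricks-sdk` and a configured Databricks auth profile; model name is illustrative):

```python
# Sketch: no explicit Databricks base URL or API key; auth is resolved via databricks-sdk.
import litellm

resp = litellm.completion(
    model="databricks/databricks-dbrx-instruct",  # illustrative model name
    messages=[{"role": "user", "content": "How are you?"}],
)
```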

* LiteLLM Minor Fixes & Improvements (09/16/2024)  (#5723)

* coverage (#5713)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Move (#5714)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix(litellm_logging.py): fix logging client re-init (#5710)

Fixes https://github.com/BerriAI/litellm/issues/5695

* fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config

Fixes https://github.com/BerriAI/litellm/issues/5682

* feat(o1_handler.py): fake streaming for openai o1 models

Fixes https://github.com/BerriAI/litellm/issues/5694
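Callers keep the usual streaming interface; litellm fetches the full o1 response and replays it as chunks. A rough sketch:

```python
# Sketch: stream=True with an o1 model is faked client-side, since o1 does not
# support server-side streaming. Assumes OPENAI_API_KEY is set.
import litellm

resp = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Summarize this change in one sentence."}],
    stream=True,
)
for chunk in resp:
    print(chunk.choices[0].delta.content or "", end="")
```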

* docs: deprecated traceloop integration in favor of native otel (#5249)

* fix: fix linting errors

* fix: fix linting errors

* fix(main.py): fix o1 import

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating materialized view (#5730)

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating materialized view

Supports making the `MonthlyGlobalSpend` view a materialized view, and exposes an endpoint to refresh it
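A hedged sketch of calling the new endpoint on a running proxy (URL, master key, and HTTP method are assumptions):

```python
# Sketch: trigger a refresh of the MonthlyGlobalSpend materialized view via the proxy.
# Assumes the proxy listens on localhost:4000, uses a master key, and the route is a POST.
import requests

resp = requests.post(
    "http://localhost:4000/global/spend/refresh",
    headers={"Authorization": "Bearer sk-1234"},  # illustrative master key
)
resp.raise_for_status()
```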

* fix(custom_logger.py): reset calltype

* fix: fix linting errors

* fix: fix linting error

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix: fix import

* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* DB test

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Coverage

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* progress

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix test name

Signed-off-by: dbczumar <corey.zumar@databricks.com>

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* test: fix test

* test(test_databricks.py): fix test

* fix(databricks/chat.py): handle custom endpoint (e.g. sagemaker)

* Apply code scanning fix for clear-text logging of sensitive information

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fix(__init__.py): fix known fireworks ai models

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Author: Krish Dholakia, 2024-09-19 13:25:29 -07:00 (committed by GitHub)
Parent: 49b2766723
Commit: d46660ea0f
24 changed files with 697 additions and 170 deletions

@@ -9,6 +9,14 @@ services:
#########################################
## Uncomment these lines to start proxy with a config.yaml file ##
# volumes:
# - ./config.yaml:/app/config.yaml <<- this is missing in the docker-compose file currently
# The below two are my suggestion
# command:
# - "--config=/app/config.yaml"
##############################################
#########################################
## Uncomment these lines to start proxy with a config.yaml file ##
# volumes:
###############################################
ports:
- "4000:4000" # Map the container port to the host, change the host port if necessary

@@ -124,16 +124,19 @@ def completion(
#### Properties of `messages`
*Note* - Each message in the array contains the following properties:
- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, or function.
- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, function or tool.
- `content`: *string or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
- `content`: *string or list[dict] or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
- `name`: *string (optional)* - The name of the author of the message. It is required if the role is "function". The name should match the name of the function represented in the content. It can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters.
- `function_call`: *object (optional)* - The name and arguments of a function that should be called, as generated by the model.
- `tool_call_id`: *str (optional)* - Tool call that this message is responding to.
[**See All Message Values**](https://github.com/BerriAI/litellm/blob/8600ec77042dacad324d3879a2bd918fc6a719fa/litellm/types/llms/openai.py#L392)
## Optional Fields
- `temperature`: *number or null (optional)* - The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic.

@@ -110,6 +110,60 @@ print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
</Tabs>
#### Redis Sentinel
<Tabs>
<TabItem value="redis-sentinel-config" label="Set on config.yaml">
```yaml
model_list:
- model_name: "*"
litellm_params:
model: "*"
litellm_settings:
cache: true
cache_params:
type: "redis"
service_name: "mymaster"
sentinel_nodes: [["localhost", 26379]]
```
</TabItem>
<TabItem value="redis-env" label="Set on .env">
You can configure Redis Sentinel in your .env by setting `REDIS_SENTINEL_NODES` and `REDIS_SERVICE_NAME`
**Example `REDIS_SENTINEL_NODES`** value
```env
REDIS_SENTINEL_NODES='[["localhost", 26379]]'
REDIS_SERVICE_NAME = "mymaster"
```
:::note
Example python script for setting redis sentinel nodes in .env:
```python
# List of sentinel nodes
sentinel_nodes = [["localhost", 26379]]
# set sentinel nodes in environment variables
os.environ["REDIS_SENTINEL_NODES"] = json.dumps(sentinel_nodes)
print("REDIS_SENTINEL_NODES", os.environ["REDIS_SENTINEL_NODES"])
```
:::
</TabItem>
</Tabs>
#### TTL
```yaml

@@ -382,7 +382,10 @@ deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
gemini_models: List = []
for key, value in model_cost.items():
def add_known_models():
for key, value in model_cost.items():
if value.get("litellm_provider") == "openai":
open_ai_chat_completion_models.append(key)
elif value.get("litellm_provider") == "text-completion-openai":
@@ -448,7 +451,12 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "gemini":
gemini_models.append(key)
elif value.get("litellm_provider") == "fireworks_ai":
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
if "-to-" not in key:
fireworks_ai_models.append(key)
add_known_models()
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [
"api.perplexity.ai",
@@ -960,7 +968,7 @@ from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.sambanova.chat import SambanovaConfig
from .llms.AI21.chat import AI21ChatConfig
from .llms.fireworks_ai import FireworksAIConfig
from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.AzureOpenAI.azure import (

@@ -12,12 +12,13 @@ import json
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
from typing import List, Optional
from typing import List, Optional, Union
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
import litellm
from litellm import get_secret
from ._logging import verbose_logger
@@ -83,7 +84,7 @@ def _redis_kwargs_from_environment():
return_dict = {}
for k, v in mapping.items():
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
value = get_secret(k, default_value=None) # type: ignore
if value is not None:
return_dict[v] = value
return return_dict
@@ -116,7 +117,7 @@ def _get_redis_client_logic(**env_overrides):
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
v = v.replace("os.environ/", "")
value = litellm.get_secret(v)
value = get_secret(v) # type: ignore
env_overrides[k] = value
redis_kwargs = {
@@ -124,13 +125,27 @@ def _get_redis_client_logic(**env_overrides):
**env_overrides,
}
_startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
"REDIS_CLUSTER_NODES"
)
if _startup_nodes is not None:
if _startup_nodes is not None and isinstance(_startup_nodes, str):
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
"REDIS_SENTINEL_NODES"
)
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
"REDIS_SERVICE_NAME"
)
if _service_name is not None:
redis_kwargs["service_name"] = _service_name
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
@@ -138,14 +153,19 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
elif (
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
):
pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
@@ -174,6 +194,44 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = async_redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
@@ -185,12 +243,13 @@ def get_redis_client(**env_overrides):
return redis.Redis.from_url(**url_kwargs)
if (
"startup_nodes" in redis_kwargs
or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
):
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
return init_redis_cluster(redis_kwargs)
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_redis_sentinel(redis_kwargs)
return redis.Redis(**redis_kwargs)
@@ -203,7 +262,7 @@ def get_redis_async_client(**env_overrides):
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
litellm.print_verbose(
verbose_logger.debug(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
@@ -225,9 +284,13 @@ def get_redis_async_client(**env_overrides):
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
)
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_async_redis_sentinel(redis_kwargs)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,

@@ -25,6 +25,9 @@ from litellm.llms.anthropic.cost_calculation import (
from litellm.llms.databricks.cost_calculator import (
cost_per_token as databricks_cost_per_token,
)
from litellm.llms.fireworks_ai.cost_calculator import (
cost_per_token as fireworks_ai_cost_per_token,
)
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@@ -217,6 +220,8 @@ def cost_per_token(
return anthropic_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "databricks":
return databricks_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "fireworks_ai":
return fireworks_ai_cost_per_token(model=model, usage=usage_block)
elif custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,

@@ -13,18 +13,14 @@ from pydantic import BaseModel
from typing_extensions import overload
import litellm
from litellm import ImageResponse, OpenAIConfig
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TranscriptionResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
@@ -674,7 +670,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj=logging_obj,
convert_tool_call_to_json_mode=json_mode,
)
elif "stream" in optional_params and optional_params["stream"] == True:
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
@@ -725,7 +721,11 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if client is None or dynamic_params:
if (
client is None
or not isinstance(client, AzureOpenAI)
or dynamic_params
):
azure_client = AzureOpenAI(**azure_client_params)
else:
azure_client = client
@@ -824,7 +824,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token,
},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -930,7 +933,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token,
},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -988,7 +994,10 @@ class AzureChatCompletion(BaseLLM):
input=data["messages"],
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"headers": {
"api_key": api_key,
"azure_ad_token": azure_ad_token,
},
"api_base": azure_client._base_url._uri_reference,
"acompletion": True,
"complete_input_dict": data,
@@ -1567,12 +1576,11 @@ class AzureChatCompletion(BaseLLM):
# return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if hasattr(e, "status_code"):
_status_code = getattr(e, "status_code")
raise AzureOpenAIError(status_code=_status_code, message=str(e))
error_code = getattr(e, "status_code", None)
if error_code is not None:
raise AzureOpenAIError(status_code=error_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))

@@ -4,6 +4,7 @@ import traceback
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
import httpx
from httpx import USE_CLIENT_DEFAULT
import litellm
@@ -76,9 +77,20 @@ class AsyncHTTPHandler:
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
self,
url: str,
params: Optional[dict] = None,
headers: Optional[dict] = None,
follow_redirects: Optional[bool] = None,
):
response = await self.client.get(url, params=params, headers=headers)
# Set follow_redirects to UseClientDefault if None
_follow_redirects = (
follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
)
response = await self.client.get(
url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
)
return response
async def post(
@@ -117,8 +129,9 @@ class AsyncHTTPHandler:
await new_client.aclose()
except httpx.TimeoutException as e:
headers = {}
if hasattr(e, "response") and e.response is not None:
for key, value in e.response.headers.items():
error_response = getattr(e, "response", None)
if error_response is not None:
for key, value in error_response.headers.items():
headers["response_headers-{}".format(key)] = value
raise litellm.Timeout(
@@ -173,8 +186,9 @@ class AsyncHTTPHandler:
await new_client.aclose()
except httpx.TimeoutException as e:
headers = {}
if hasattr(e, "response") and e.response is not None:
for key, value in e.response.headers.items():
error_response = getattr(e, "response", None)
if error_response is not None:
for key, value in error_response.headers.items():
headers["response_headers-{}".format(key)] = value
raise litellm.Timeout(
@@ -303,9 +317,20 @@ class HTTPHandler:
self.client.close()
def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
self,
url: str,
params: Optional[dict] = None,
headers: Optional[dict] = None,
follow_redirects: Optional[bool] = None,
):
response = self.client.get(url, params=params, headers=headers)
# Set follow_redirects to UseClientDefault if None
_follow_redirects = (
follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
)
response = self.client.get(
url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
)
return response
def post(

@@ -244,6 +244,34 @@ class DatabricksChatCompletion(BaseLLM):
# makes headers for API call
def _get_databricks_credentials(
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
) -> Tuple[str, dict]:
headers = headers or {"Content-Type": "application/json"}
try:
from databricks.sdk import WorkspaceClient
databricks_client = WorkspaceClient()
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
if api_key is None:
databricks_auth_headers: dict[str, str] = (
databricks_client.config.authenticate()
)
headers = {**databricks_auth_headers, **headers}
return api_base, headers
except ImportError:
raise DatabricksError(
status_code=400,
message=(
"If the Databricks base URL and API key are not set, the databricks-sdk "
"Python library must be installed. Please install the databricks-sdk, set "
"{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
"or provide the base URL and API key as arguments."
),
)
def _validate_environment(
self,
api_key: Optional[str],
@@ -253,16 +281,26 @@ class DatabricksChatCompletion(BaseLLM):
headers: Optional[dict],
) -> Tuple[str, dict]:
if api_key is None and headers is None:
if custom_endpoint:
raise DatabricksError(
status_code=400,
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
else:
api_base, headers = self._get_databricks_credentials(
api_base=api_base, api_key=api_key, headers=headers
)
if api_base is None:
if custom_endpoint:
raise DatabricksError(
status_code=400,
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
)
else:
api_base, headers = self._get_databricks_credentials(
api_base=api_base, api_key=api_key, headers=headers
)
if headers is None:
headers = {
@@ -273,6 +311,9 @@ class DatabricksChatCompletion(BaseLLM):
if api_key is not None:
headers.update({"Authorization": "Bearer {}".format(api_key)})
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
if endpoint_type == "chat_completions" and custom_endpoint is not True:
api_base = "{}/chat/completions".format(api_base)
elif endpoint_type == "embeddings" and custom_endpoint is not True:
@@ -520,7 +561,8 @@ class DatabricksChatCompletion(BaseLLM):
response_json = response.json()
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code, message=e.response.text
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException as e:
raise DatabricksError(

@@ -1,8 +1,6 @@
import types
from typing import Literal, Optional, Union
import litellm
class FireworksAIConfig:
"""

@@ -0,0 +1,72 @@
"""
For calculating cost of fireworks ai serverless inference models.
"""
from typing import Tuple
from litellm.types.utils import Usage
from litellm.utils import get_model_info
# Extract the number of billion parameters from the model name
# only used for fireworks_ai LLMs
def get_model_params_and_category(model_name: str) -> str:
"""
Helper function for calculating fireworks ai pricing.
Returns:
- str: model pricing category if mapped else received model name
"""
import re
model_name = model_name.lower()
# Check for MoE models in the form <number>x<number>b
moe_match = re.search(r"(\d+)x(\d+)b", model_name)
if moe_match:
total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
if total_billion <= 56:
return "fireworks-ai-moe-up-to-56b"
elif total_billion <= 176:
return "fireworks-ai-56b-to-176b"
# Check for standard models in the form <number>b
re_params_match = re.search(r"(\d+)b", model_name)
if re_params_match is not None:
params_match = str(re_params_match.group(1))
params_billion = float(params_match)
# Determine the category based on the number of parameters
if params_billion <= 16.0:
return "fireworks-ai-up-to-16b"
elif params_billion <= 80.0:
return "fireworks-ai-16b-80b"
# If no matches, return the original model_name
return model_name
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Input:
- model: str, the model name without provider prefix
- usage: LiteLLM Usage block, containing prompt and completion token counts
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
base_model = get_model_params_and_category(model_name=model)
## GET MODEL INFO
model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
## CALCULATE INPUT COST
prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
## CALCULATE OUTPUT COST
completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
return prompt_cost, completion_cost

@@ -7,6 +7,7 @@ import base64
from httpx import Response
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
@@ -58,7 +59,7 @@ async def async_convert_url_to_base64(url: str) -> str:
client = litellm.module_level_aclient
for _ in range(3):
try:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
return _process_image_response(response, url)
except:
pass
@@ -75,9 +76,11 @@ def convert_url_to_base64(url: str) -> str:
client = litellm.module_level_client
for _ in range(3):
try:
response = client.get(url)
response = client.get(url, follow_redirects=True)
return _process_image_response(response, url)
except:
except Exception as e:
verbose_logger.exception(e)
# print(e)
pass
raise Exception(
f"Error: Unable to fetch image from URL after 3 attempts. url={url}"

@@ -5452,6 +5452,26 @@
"mode": "chat",
"supports_function_calling": true,
"source": "https://fireworks.ai/pricing"
},
"fireworks-ai-up-to-16b": {
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-16.1b-to-80b": {
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-moe-up-to-56b": {
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-56b-to-176b": {
"input_cost_per_token": 0.0000012,
"output_cost_per_token": 0.0000012,
"litellm_provider": "fireworks_ai"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,

@@ -20,3 +20,10 @@ model_list:
litellm_params:
model: o1-preview
litellm_settings:
cache: true
# cache_params:
# type: "redis"
# service_name: "mymaster"
# sentinel_nodes:
# - ["localhost", 26379]

@@ -653,27 +653,12 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
return
try:
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
# Set your Azure Key Vault URI
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
# Set your Azure AD application/client ID, client secret, and tenant ID
client_id = os.getenv("AZURE_CLIENT_ID", None)
client_secret = os.getenv("AZURE_CLIENT_SECRET", None)
tenant_id = os.getenv("AZURE_TENANT_ID", None)
if (
KVUri is not None
and client_id is not None
and client_secret is not None
and tenant_id is not None
):
# Initialize the ClientSecretCredential
# credential = ClientSecretCredential(
# client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
# )
credential = DefaultAzureCredential()
# Create the SecretClient using the credential
@@ -681,10 +666,6 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
else:
raise Exception(
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
)
except Exception as e:
_error_str = str(e)
verbose_proxy_logger.exception(
@@ -1626,8 +1607,8 @@ class ProxyConfig:
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
self._init_cache(cache_params=cache_params)
if litellm.cache is not None:
verbose_proxy_logger.debug( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy= {vars(litellm.cache.cache)}{vars(litellm.cache)}{reset_color_code}"
verbose_proxy_logger.debug(
f"{blue_color_code}Set Cache on LiteLLM Proxy{reset_color_code}"
)
elif key == "cache" and value is False:
pass

@@ -4019,7 +4019,9 @@ class Router:
_model_info=_model_info,
)
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
verbose_router_logger.debug(
f"\nInitialized Model List {self.get_model_names()}"
)
self.model_names = [m["model_name"] for m in model_list]
def _add_deployment(self, deployment: Deployment) -> Deployment:
@@ -4630,6 +4632,7 @@ class Router:
if hasattr(self, "model_list"):
returned_models: List[DeploymentTypedDict] = []
if hasattr(self, "model_group_alias"):
for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
@@ -5030,7 +5033,7 @@ class Router:
# return the first deployment where the `model` matches the specified deployment name
return deployment_model, deployment
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
)
elif model in self.get_model_ids():
deployment = self.get_model_info(id=model)

@@ -2045,3 +2045,57 @@ async def test_proxy_logging_setup():
pl_obj = ProxyLogging(user_api_key_cache=DualCache())
assert pl_obj.internal_usage_cache.always_read_redis is True
@pytest.mark.skip(reason="local test. Requires sentinel setup.")
@pytest.mark.asyncio
async def test_redis_sentinel_caching():
"""
Init redis client
- write to client
- read from client
"""
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="redis",
# host=os.environ["REDIS_HOST"],
# port=os.environ["REDIS_PORT"],
# password=os.environ["REDIS_PASSWORD"],
service_name="mymaster",
sentinel_nodes=[("localhost", 26379)],
)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
)
cache_key = litellm.cache.get_cache_key(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"cache_key: {cache_key}")
litellm.cache.add_cache(result=response1, cache_key=cache_key)
print(f"cache key pre async get: {cache_key}")
stored_val = litellm.cache.get_cache(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"stored_val: {stored_val}")
assert stored_val["id"] == response1.id
stored_val_2 = await litellm.cache.async_get_cache(
model="gpt-3.5-turbo",
messages=messages,
)
print(f"stored_val: {stored_val}")
assert stored_val_2["id"] == response1.id

@@ -1255,3 +1255,16 @@ def test_completion_cost_databricks_embedding(model):
print(resp)
cost = completion_cost(completion_response=resp)
def test_completion_cost_fireworks_ai():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
messages = [{"role": "user", "content": "Hey, how's it going?"}]
resp = litellm.completion(
model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
) # works fine
print(resp)
cost = completion_cost(completion_response=resp)

@@ -432,3 +432,7 @@ def test_vertex_only_image_user_message():
), "Invalid gemini input. Got={}, Expected={}".format(
content, expected_response[idx]
)
def test_convert_url():
convert_url_to_base64("https://picsum.photos/id/237/200/300")

@@ -6830,7 +6830,10 @@ def exception_type(
llm_provider=custom_llm_provider,
model=model,
)
elif original_exception.status_code == 401:
elif (
original_exception.status_code == 401
or original_exception.status_code == 403
):
exception_mapping_worked = True
raise AuthenticationError(
message=f"{custom_llm_provider}Exception - {original_exception.message}",

@@ -5452,6 +5452,26 @@
"mode": "chat",
"supports_function_calling": true,
"source": "https://fireworks.ai/pricing"
},
"fireworks-ai-up-to-16b": {
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0000002,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-16.1b-to-80b": {
"input_cost_per_token": 0.0000009,
"output_cost_per_token": 0.0000009,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-moe-up-to-56b": {
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000005,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-56b-to-176b": {
"input_cost_per_token": 0.0000012,
"output_cost_per_token": 0.0000012,
"litellm_provider": "fireworks_ai"
},
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
"max_tokens": 16384,

@@ -7,10 +7,17 @@ from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import litellm
from litellm.exceptions import BadRequestError, InternalServerError
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
try:
import databricks.sdk
databricks_sdk_installed = True
except ImportError:
databricks_sdk_installed = False
def mock_chat_response() -> Dict[str, Any]:
return {
@@ -33,8 +40,8 @@ def mock_chat_response() -> Dict[str, Any]:
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
"total_tokens": 268,
"completion_tokens_details": None,
"total_tokens": 268,
},
"system_fingerprint": None,
}
@@ -195,7 +202,14 @@ def mock_embedding_response() -> Dict[str, Any]:
@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
monkeypatch, set_base
):
# Simulate that the databricks SDK is not installed
monkeypatch.setitem(sys.modules, "databricks.sdk", None)
err_msg = "the Databricks base URL and API key are not set"
if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
@@ -204,11 +218,11 @@ def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
monkeypatch.delenv("DATABRICKS_API_BASE")
err_msg = "A call is being made to LLM Provider but no api base is set"
monkeypatch.delenv(
"DATABRICKS_API_BASE",
)
with pytest.raises(BadRequestError) as exc:
litellm.completion(
@@ -422,6 +436,67 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
)
@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()
expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}
base_url = "https://my.workspace.cloud.databricks.com"
api_key = "dapimykey"
headers = {
"Authorization": f"Bearer {api_key}",
}
messages = [{"role": "user", "content": "How are you?"}]
mock_workspace_client: WorkspaceClient = MagicMock()
mock_config: Config = MagicMock()
# Simulate the behavior of the config property and its methods
mock_config.authenticate.side_effect = lambda: headers
mock_config.host = base_url # Assign directly as if it's a property
mock_workspace_client.config = mock_config
with patch(
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
assert response.to_dict() == expected_response_json
mock_post.assert_called_once_with(
f"{base_url}/serving-endpoints/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)
def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
@@ -500,3 +575,59 @@ def test_embeddings_with_async_http_handler(monkeypatch):
}
),
)
@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
from databricks.sdk import WorkspaceClient
from databricks.sdk.config import Config
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()
base_url = "https://my.workspace.cloud.databricks.com"
api_key = "dapimykey"
headers = {
"Authorization": f"Bearer {api_key}",
}
inputs = ["Hello", "World"]
mock_workspace_client: WorkspaceClient = MagicMock()
mock_config: Config = MagicMock()
# Simulate the behavior of the config property and its methods
mock_config.authenticate.side_effect = lambda: headers
mock_config.host = base_url # Assign directly as if it's a property
mock_workspace_client.config = mock_config
with patch(
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.embedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=sync_handler,
extraparam="testpassingextraparam",
)
assert response.to_dict() == mock_embedding_response()
mock_post.assert_called_once_with(
f"{base_url}/serving-endpoints/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)

@@ -7,7 +7,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.llms.fireworks_ai import FireworksAIConfig
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
fireworks = FireworksAIConfig()

@@ -149,7 +149,9 @@ def test_all_model_configs():
{"max_completion_tokens": 10}, {}, "llama3"
) == {"max_tokens": 10}
from litellm.llms.fireworks_ai import FireworksAIConfig
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
FireworksAIConfig,
)
assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
assert FireworksAIConfig().map_openai_params(