forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (09/18/2024) (#5772)
* fix(proxy_server.py): fix azure key vault logic to not require client id/secret * feat(cost_calculator.py): support fireworks ai cost tracking * build(docker-compose.yml): add lines for mounting config.yaml to docker compose Closes https://github.com/BerriAI/litellm/issues/5739 * fix(input.md): update docs to clarify litellm supports content as a list of dictionaries Fixes https://github.com/BerriAI/litellm/issues/5755 * fix(input.md): update input.md to include all message values * fix(image_handling.py): follow image url redirects Fixes https://github.com/BerriAI/litellm/issues/5763 * fix(router.py): Fix model key/base leak in error message Fixes https://github.com/BerriAI/litellm/issues/5762 * fix(http_handler.py): fix linting error * fix(azure.py): fix logging to show azure_ad_token being used Fixes https://github.com/BerriAI/litellm/issues/5767 * fix(_redis.py): add redis sentinel support Closes https://github.com/BerriAI/litellm/issues/4381 * feat(_redis.py): add redis sentinel support Closes https://github.com/BerriAI/litellm/issues/4381 * test(test_completion_cost.py): fix test * Databricks Integration: Integrate Databricks SDK as optional mechanism for fetching API base and token, if unspecified (#5746) * LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) * coverage (#5713) Signed-off-by: dbczumar <corey.zumar@databricks.com> * Move (#5714) Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix(litellm_logging.py): fix logging client re-init (#5710) Fixes https://github.com/BerriAI/litellm/issues/5695 * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config Fixes https://github.com/BerriAI/litellm/issues/5682 * feat(o1_handler.py): fake streaming for openai o1 models Fixes https://github.com/BerriAI/litellm/issues/5694 * docs: deprecated traceloop integration in favor of native otel (#5249) * fix: fix linting errors * fix: fix linting errors * fix(main.py): fix o1 import --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit <nirga@users.noreply.github.com> * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730) * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it * fix(custom_logger.py): reset calltype * fix: fix linting errors * fix: fix linting error * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix: fix import * Fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * DB test Signed-off-by: dbczumar <corey.zumar@databricks.com> * Coverage Signed-off-by: dbczumar <corey.zumar@databricks.com> * progress Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix Signed-off-by: dbczumar <corey.zumar@databricks.com> * fix test name Signed-off-by: dbczumar <corey.zumar@databricks.com> --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com> Co-authored-by: Nir Gazit <nirga@users.noreply.github.com> * test: fix test * test(test_databricks.py): fix test * fix(databricks/chat.py): handle custom endpoint (e.g. sagemaker) * Apply code scanning fix for clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fix(__init__.py): fix known fireworks ai models --------- Signed-off-by: dbczumar <corey.zumar@databricks.com> Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit <nirga@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
parent
49b2766723
commit
d46660ea0f
24 changed files with 697 additions and 170 deletions
|
@ -9,6 +9,14 @@ services:
|
|||
#########################################
|
||||
## Uncomment these lines to start proxy with a config.yaml file ##
|
||||
# volumes:
|
||||
# - ./config.yaml:/app/config.yaml <<- this is missing in the docker-compose file currently
|
||||
# The below two are my suggestion
|
||||
# command:
|
||||
# - "--config=/app/config.yaml"
|
||||
##############################################
|
||||
#########################################
|
||||
## Uncomment these lines to start proxy with a config.yaml file ##
|
||||
# volumes:
|
||||
###############################################
|
||||
ports:
|
||||
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
||||
|
|
|
@ -124,16 +124,19 @@ def completion(
|
|||
#### Properties of `messages`
|
||||
*Note* - Each message in the array contains the following properties:
|
||||
|
||||
- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, or function.
|
||||
- `role`: *string* - The role of the message's author. Roles can be: system, user, assistant, function or tool.
|
||||
|
||||
- `content`: *string or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
|
||||
- `content`: *string or list[dict] or null* - The contents of the message. It is required for all messages, but may be null for assistant messages with function calls.
|
||||
|
||||
- `name`: *string (optional)* - The name of the author of the message. It is required if the role is "function". The name should match the name of the function represented in the content. It can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters.
|
||||
|
||||
- `function_call`: *object (optional)* - The name and arguments of a function that should be called, as generated by the model.
|
||||
|
||||
- `tool_call_id`: *str (optional)* - Tool call that this message is responding to.
|
||||
|
||||
|
||||
[**See All Message Values**](https://github.com/BerriAI/litellm/blob/8600ec77042dacad324d3879a2bd918fc6a719fa/litellm/types/llms/openai.py#L392)
|
||||
|
||||
## Optional Fields
|
||||
|
||||
- `temperature`: *number or null (optional)* - The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic.
|
||||
|
|
|
@ -110,6 +110,60 @@ print("REDIS_CLUSTER_NODES", os.environ["REDIS_CLUSTER_NODES"])
|
|||
|
||||
</Tabs>
|
||||
|
||||
#### Redis Sentinel
|
||||
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="redis-sentinel-config" label="Set on config.yaml">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "*"
|
||||
litellm_params:
|
||||
model: "*"
|
||||
|
||||
|
||||
litellm_settings:
|
||||
cache: true
|
||||
cache_params:
|
||||
type: "redis"
|
||||
service_name: "mymaster"
|
||||
sentinel_nodes: [["localhost", 26379]]
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="redis-env" label="Set on .env">
|
||||
|
||||
You can configure redis sentinel in your .env by setting `REDIS_SENTINEL_NODES` in your .env
|
||||
|
||||
**Example `REDIS_SENTINEL_NODES`** value
|
||||
|
||||
```env
|
||||
REDIS_SENTINEL_NODES='[["localhost", 26379]]'
|
||||
REDIS_SERVICE_NAME = "mymaster"
|
||||
```
|
||||
|
||||
:::note
|
||||
|
||||
Example python script for setting redis cluster nodes in .env:
|
||||
|
||||
```python
|
||||
# List of startup nodes
|
||||
sentinel_nodes = [["localhost", 26379]]
|
||||
|
||||
# set startup nodes in environment variables
|
||||
os.environ["REDIS_SENTINEL_NODES"] = json.dumps(sentinel_nodes)
|
||||
print("REDIS_SENTINEL_NODES", os.environ["REDIS_SENTINEL_NODES"])
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
#### TTL
|
||||
|
||||
```yaml
|
||||
|
|
|
@ -382,7 +382,10 @@ deepinfra_models: List = []
|
|||
perplexity_models: List = []
|
||||
watsonx_models: List = []
|
||||
gemini_models: List = []
|
||||
for key, value in model_cost.items():
|
||||
|
||||
|
||||
def add_known_models():
|
||||
for key, value in model_cost.items():
|
||||
if value.get("litellm_provider") == "openai":
|
||||
open_ai_chat_completion_models.append(key)
|
||||
elif value.get("litellm_provider") == "text-completion-openai":
|
||||
|
@ -448,7 +451,12 @@ for key, value in model_cost.items():
|
|||
elif value.get("litellm_provider") == "gemini":
|
||||
gemini_models.append(key)
|
||||
elif value.get("litellm_provider") == "fireworks_ai":
|
||||
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
|
||||
if "-to-" not in key:
|
||||
fireworks_ai_models.append(key)
|
||||
|
||||
|
||||
add_known_models()
|
||||
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
|
||||
openai_compatible_endpoints: List = [
|
||||
"api.perplexity.ai",
|
||||
|
@ -960,7 +968,7 @@ from .llms.nvidia_nim import NvidiaNimConfig
|
|||
from .llms.cerebras.chat import CerebrasConfig
|
||||
from .llms.sambanova.chat import SambanovaConfig
|
||||
from .llms.AI21.chat import AI21ChatConfig
|
||||
from .llms.fireworks_ai import FireworksAIConfig
|
||||
from .llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
|
||||
from .llms.volcengine import VolcEngineConfig
|
||||
from .llms.text_completion_codestral import MistralTextCompletionConfig
|
||||
from .llms.AzureOpenAI.azure import (
|
||||
|
|
|
@ -12,12 +12,13 @@ import json
|
|||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import redis # type: ignore
|
||||
import redis.asyncio as async_redis # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
||||
|
@ -83,7 +84,7 @@ def _redis_kwargs_from_environment():
|
|||
|
||||
return_dict = {}
|
||||
for k, v in mapping.items():
|
||||
value = litellm.get_secret(k, default_value=None) # check os.environ/key vault
|
||||
value = get_secret(k, default_value=None) # type: ignore
|
||||
if value is not None:
|
||||
return_dict[v] = value
|
||||
return return_dict
|
||||
|
@ -116,7 +117,7 @@ def _get_redis_client_logic(**env_overrides):
|
|||
for k, v in env_overrides.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
v = v.replace("os.environ/", "")
|
||||
value = litellm.get_secret(v)
|
||||
value = get_secret(v) # type: ignore
|
||||
env_overrides[k] = value
|
||||
|
||||
redis_kwargs = {
|
||||
|
@ -124,13 +125,27 @@ def _get_redis_client_logic(**env_overrides):
|
|||
**env_overrides,
|
||||
}
|
||||
|
||||
_startup_nodes = redis_kwargs.get("startup_nodes", None) or litellm.get_secret(
|
||||
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
|
||||
"REDIS_CLUSTER_NODES"
|
||||
)
|
||||
|
||||
if _startup_nodes is not None:
|
||||
if _startup_nodes is not None and isinstance(_startup_nodes, str):
|
||||
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
|
||||
|
||||
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
|
||||
"REDIS_SENTINEL_NODES"
|
||||
)
|
||||
|
||||
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
|
||||
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
|
||||
|
||||
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
|
||||
"REDIS_SERVICE_NAME"
|
||||
)
|
||||
|
||||
if _service_name is not None:
|
||||
redis_kwargs["service_name"] = _service_name
|
||||
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop("host", None)
|
||||
redis_kwargs.pop("port", None)
|
||||
|
@ -138,14 +153,19 @@ def _get_redis_client_logic(**env_overrides):
|
|||
redis_kwargs.pop("password", None)
|
||||
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
|
||||
pass
|
||||
elif (
|
||||
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
|
||||
):
|
||||
pass
|
||||
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
|
||||
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis_kwargs
|
||||
|
||||
|
||||
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
||||
_redis_cluster_nodes_in_env = litellm.get_secret("REDIS_CLUSTER_NODES")
|
||||
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
|
||||
if _redis_cluster_nodes_in_env is not None:
|
||||
try:
|
||||
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
|
||||
|
@ -174,6 +194,44 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
|||
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
|
||||
|
||||
|
||||
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
|
||||
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
||||
service_name = redis_kwargs.get("service_name")
|
||||
|
||||
if not sentinel_nodes or not service_name:
|
||||
raise ValueError(
|
||||
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
||||
)
|
||||
|
||||
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
||||
|
||||
# Set up the Sentinel client
|
||||
sentinel = redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
|
||||
|
||||
# Return the master instance for the given service
|
||||
|
||||
return sentinel.master_for(service_name)
|
||||
|
||||
|
||||
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
|
||||
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
||||
service_name = redis_kwargs.get("service_name")
|
||||
|
||||
if not sentinel_nodes or not service_name:
|
||||
raise ValueError(
|
||||
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
||||
)
|
||||
|
||||
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
||||
|
||||
# Set up the Sentinel client
|
||||
sentinel = async_redis.Sentinel(sentinel_nodes, socket_timeout=0.1)
|
||||
|
||||
# Return the master instance for the given service
|
||||
|
||||
return sentinel.master_for(service_name)
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
|
@ -185,12 +243,13 @@ def get_redis_client(**env_overrides):
|
|||
|
||||
return redis.Redis.from_url(**url_kwargs)
|
||||
|
||||
if (
|
||||
"startup_nodes" in redis_kwargs
|
||||
or litellm.get_secret("REDIS_CLUSTER_NODES") is not None
|
||||
):
|
||||
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
|
||||
return init_redis_cluster(redis_kwargs)
|
||||
|
||||
# Check for Redis Sentinel
|
||||
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
||||
return _init_redis_sentinel(redis_kwargs)
|
||||
|
||||
return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
|
@ -203,7 +262,7 @@ def get_redis_async_client(**env_overrides):
|
|||
if arg in args:
|
||||
url_kwargs[arg] = redis_kwargs[arg]
|
||||
else:
|
||||
litellm.print_verbose(
|
||||
verbose_logger.debug(
|
||||
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
|
||||
arg
|
||||
)
|
||||
|
@ -225,9 +284,13 @@ def get_redis_async_client(**env_overrides):
|
|||
new_startup_nodes.append(ClusterNode(**item))
|
||||
redis_kwargs.pop("startup_nodes")
|
||||
return async_redis.RedisCluster(
|
||||
startup_nodes=new_startup_nodes, **cluster_kwargs
|
||||
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
|
||||
)
|
||||
|
||||
# Check for Redis Sentinel
|
||||
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
||||
return _init_async_redis_sentinel(redis_kwargs)
|
||||
|
||||
return async_redis.Redis(
|
||||
socket_timeout=5,
|
||||
**redis_kwargs,
|
||||
|
|
|
@ -25,6 +25,9 @@ from litellm.llms.anthropic.cost_calculation import (
|
|||
from litellm.llms.databricks.cost_calculator import (
|
||||
cost_per_token as databricks_cost_per_token,
|
||||
)
|
||||
from litellm.llms.fireworks_ai.cost_calculator import (
|
||||
cost_per_token as fireworks_ai_cost_per_token,
|
||||
)
|
||||
from litellm.rerank_api.types import RerankResponse
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||
|
@ -217,6 +220,8 @@ def cost_per_token(
|
|||
return anthropic_cost_per_token(model=model, usage=usage_block)
|
||||
elif custom_llm_provider == "databricks":
|
||||
return databricks_cost_per_token(model=model, usage=usage_block)
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
return fireworks_ai_cost_per_token(model=model, usage=usage_block)
|
||||
elif custom_llm_provider == "gemini":
|
||||
return google_cost_per_token(
|
||||
model=model_without_prefix,
|
||||
|
|
|
@ -13,18 +13,14 @@ from pydantic import BaseModel
|
|||
from typing_extensions import overload
|
||||
|
||||
import litellm
|
||||
from litellm import ImageResponse, OpenAIConfig
|
||||
from litellm.caching import DualCache
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.utils import FileTypes # type: ignore
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.utils import (
|
||||
Choices,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
ModelResponse,
|
||||
TranscriptionResponse,
|
||||
UnsupportedParamsError,
|
||||
convert_to_model_response_object,
|
||||
get_secret,
|
||||
|
@ -674,7 +670,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
logging_obj=logging_obj,
|
||||
convert_tool_call_to_json_mode=json_mode,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
|
@ -725,7 +721,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if client is None or dynamic_params:
|
||||
if (
|
||||
client is None
|
||||
or not isinstance(client, AzureOpenAI)
|
||||
or dynamic_params
|
||||
):
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
|
@ -824,7 +824,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
|
@ -930,7 +933,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
|
@ -988,7 +994,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
input=data["messages"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
|
@ -1567,12 +1576,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
# return response
|
||||
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if hasattr(e, "status_code"):
|
||||
_status_code = getattr(e, "status_code")
|
||||
raise AzureOpenAIError(status_code=_status_code, message=str(e))
|
||||
error_code = getattr(e, "status_code", None)
|
||||
if error_code is not None:
|
||||
raise AzureOpenAIError(status_code=error_code, message=str(e))
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import traceback
|
|||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
|
||||
|
||||
import httpx
|
||||
from httpx import USE_CLIENT_DEFAULT
|
||||
|
||||
import litellm
|
||||
|
||||
|
@ -76,9 +77,20 @@ class AsyncHTTPHandler:
|
|||
await self.client.aclose()
|
||||
|
||||
async def get(
|
||||
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||
self,
|
||||
url: str,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
follow_redirects: Optional[bool] = None,
|
||||
):
|
||||
response = await self.client.get(url, params=params, headers=headers)
|
||||
# Set follow_redirects to UseClientDefault if None
|
||||
_follow_redirects = (
|
||||
follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
|
||||
)
|
||||
|
||||
response = await self.client.get(
|
||||
url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
|
||||
)
|
||||
return response
|
||||
|
||||
async def post(
|
||||
|
@ -117,8 +129,9 @@ class AsyncHTTPHandler:
|
|||
await new_client.aclose()
|
||||
except httpx.TimeoutException as e:
|
||||
headers = {}
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
for key, value in e.response.headers.items():
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_response is not None:
|
||||
for key, value in error_response.headers.items():
|
||||
headers["response_headers-{}".format(key)] = value
|
||||
|
||||
raise litellm.Timeout(
|
||||
|
@ -173,8 +186,9 @@ class AsyncHTTPHandler:
|
|||
await new_client.aclose()
|
||||
except httpx.TimeoutException as e:
|
||||
headers = {}
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
for key, value in e.response.headers.items():
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_response is not None:
|
||||
for key, value in error_response.headers.items():
|
||||
headers["response_headers-{}".format(key)] = value
|
||||
|
||||
raise litellm.Timeout(
|
||||
|
@ -303,9 +317,20 @@ class HTTPHandler:
|
|||
self.client.close()
|
||||
|
||||
def get(
|
||||
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
|
||||
self,
|
||||
url: str,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
follow_redirects: Optional[bool] = None,
|
||||
):
|
||||
response = self.client.get(url, params=params, headers=headers)
|
||||
# Set follow_redirects to UseClientDefault if None
|
||||
_follow_redirects = (
|
||||
follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
|
||||
)
|
||||
|
||||
response = self.client.get(
|
||||
url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
|
||||
)
|
||||
return response
|
||||
|
||||
def post(
|
||||
|
|
|
@ -244,6 +244,34 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
|
||||
# makes headers for API call
|
||||
|
||||
def _get_databricks_credentials(
|
||||
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
|
||||
) -> Tuple[str, dict]:
|
||||
headers = headers or {"Content-Type": "application/json"}
|
||||
try:
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
databricks_client = WorkspaceClient()
|
||||
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
|
||||
|
||||
if api_key is None:
|
||||
databricks_auth_headers: dict[str, str] = (
|
||||
databricks_client.config.authenticate()
|
||||
)
|
||||
headers = {**databricks_auth_headers, **headers}
|
||||
|
||||
return api_base, headers
|
||||
except ImportError:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message=(
|
||||
"If the Databricks base URL and API key are not set, the databricks-sdk "
|
||||
"Python library must be installed. Please install the databricks-sdk, set "
|
||||
"{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
|
||||
"or provide the base URL and API key as arguments."
|
||||
),
|
||||
)
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
|
@ -253,16 +281,26 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
if custom_endpoint:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
else:
|
||||
api_base, headers = self._get_databricks_credentials(
|
||||
api_base=api_base, api_key=api_key, headers=headers
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
if custom_endpoint:
|
||||
raise DatabricksError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
else:
|
||||
api_base, headers = self._get_databricks_credentials(
|
||||
api_base=api_base, api_key=api_key, headers=headers
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
|
@ -273,6 +311,9 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if endpoint_type == "chat_completions" and custom_endpoint is not True:
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings" and custom_endpoint is not True:
|
||||
|
@ -520,7 +561,8 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise DatabricksError(
|
||||
status_code=e.response.status_code, message=e.response.text
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise DatabricksError(
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import types
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class FireworksAIConfig:
|
||||
"""
|
72
litellm/llms/fireworks_ai/cost_calculator.py
Normal file
72
litellm/llms/fireworks_ai/cost_calculator.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
"""
|
||||
For calculating cost of fireworks ai serverless inference models.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
# Extract the number of billion parameters from the model name
|
||||
# only used for together_computer LLMs
|
||||
def get_model_params_and_category(model_name: str) -> str:
|
||||
"""
|
||||
Helper function for calculating together ai pricing.
|
||||
|
||||
Returns:
|
||||
- str: model pricing category if mapped else received model name
|
||||
"""
|
||||
import re
|
||||
|
||||
model_name = model_name.lower()
|
||||
|
||||
# Check for MoE models in the form <number>x<number>b
|
||||
moe_match = re.search(r"(\d+)x(\d+)b", model_name)
|
||||
if moe_match:
|
||||
total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
|
||||
if total_billion <= 56:
|
||||
return "fireworks-ai-moe-up-to-56b"
|
||||
elif total_billion <= 176:
|
||||
return "fireworks-ai-56b-to-176b"
|
||||
|
||||
# Check for standard models in the form <number>b
|
||||
re_params_match = re.search(r"(\d+)b", model_name)
|
||||
if re_params_match is not None:
|
||||
params_match = str(re_params_match.group(1))
|
||||
params_billion = float(params_match)
|
||||
|
||||
# Determine the category based on the number of parameters
|
||||
if params_billion <= 16.0:
|
||||
return "fireworks-ai-up-to-16b"
|
||||
elif params_billion <= 80.0:
|
||||
return "fireworks-ai-16b-80b"
|
||||
|
||||
# If no matches, return the original model_name
|
||||
return model_name
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
- usage: LiteLLM Usage block, containing anthropic caching information
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
base_model = get_model_params_and_category(model_name=model)
|
||||
|
||||
## GET MODEL INFO
|
||||
model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
|
||||
|
||||
## CALCULATE INPUT COST
|
||||
|
||||
prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
|
||||
|
||||
## CALCULATE OUTPUT COST
|
||||
completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
|
||||
|
||||
return prompt_cost, completion_cost
|
|
@ -7,6 +7,7 @@ import base64
|
|||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
|
@ -58,7 +59,7 @@ async def async_convert_url_to_base64(url: str) -> str:
|
|||
client = litellm.module_level_aclient
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
return _process_image_response(response, url)
|
||||
except:
|
||||
pass
|
||||
|
@ -75,9 +76,11 @@ def convert_url_to_base64(url: str) -> str:
|
|||
client = litellm.module_level_client
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = client.get(url)
|
||||
response = client.get(url, follow_redirects=True)
|
||||
return _process_image_response(response, url)
|
||||
except:
|
||||
except Exception as e:
|
||||
verbose_logger.exception(e)
|
||||
# print(e)
|
||||
pass
|
||||
raise Exception(
|
||||
f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
|
||||
|
|
|
@ -5452,6 +5452,26 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"source": "https://fireworks.ai/pricing"
|
||||
},
|
||||
"fireworks-ai-up-to-16b": {
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0000002,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-16.1b-to-80b": {
|
||||
"input_cost_per_token": 0.0000009,
|
||||
"output_cost_per_token": 0.0000009,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-moe-up-to-56b": {
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000005,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-56b-to-176b": {
|
||||
"input_cost_per_token": 0.0000012,
|
||||
"output_cost_per_token": 0.0000012,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 16384,
|
||||
|
|
|
@ -20,3 +20,10 @@ model_list:
|
|||
litellm_params:
|
||||
model: o1-preview
|
||||
|
||||
litellm_settings:
|
||||
cache: true
|
||||
# cache_params:
|
||||
# type: "redis"
|
||||
# service_name: "mymaster"
|
||||
# sentinel_nodes:
|
||||
# - ["localhost", 26379]
|
|
@ -653,27 +653,12 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
return
|
||||
|
||||
try:
|
||||
from azure.identity import ClientSecretCredential, DefaultAzureCredential
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
# Set your Azure Key Vault URI
|
||||
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
|
||||
|
||||
# Set your Azure AD application/client ID, client secret, and tenant ID
|
||||
client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
client_secret = os.getenv("AZURE_CLIENT_SECRET", None)
|
||||
tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||
|
||||
if (
|
||||
KVUri is not None
|
||||
and client_id is not None
|
||||
and client_secret is not None
|
||||
and tenant_id is not None
|
||||
):
|
||||
# Initialize the ClientSecretCredential
|
||||
# credential = ClientSecretCredential(
|
||||
# client_id=client_id, client_secret=client_secret, tenant_id=tenant_id
|
||||
# )
|
||||
credential = DefaultAzureCredential()
|
||||
|
||||
# Create the SecretClient using the credential
|
||||
|
@ -681,10 +666,6 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
|
||||
litellm.secret_manager_client = client
|
||||
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
|
||||
else:
|
||||
raise Exception(
|
||||
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
|
||||
)
|
||||
except Exception as e:
|
||||
_error_str = str(e)
|
||||
verbose_proxy_logger.exception(
|
||||
|
@ -1626,8 +1607,8 @@ class ProxyConfig:
|
|||
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
||||
self._init_cache(cache_params=cache_params)
|
||||
if litellm.cache is not None:
|
||||
verbose_proxy_logger.debug( # noqa
|
||||
f"{blue_color_code}Set Cache on LiteLLM Proxy= {vars(litellm.cache.cache)}{vars(litellm.cache)}{reset_color_code}"
|
||||
verbose_proxy_logger.debug(
|
||||
f"{blue_color_code}Set Cache on LiteLLM Proxy{reset_color_code}"
|
||||
)
|
||||
elif key == "cache" and value is False:
|
||||
pass
|
||||
|
|
|
@ -4019,7 +4019,9 @@ class Router:
|
|||
_model_info=_model_info,
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
|
||||
verbose_router_logger.debug(
|
||||
f"\nInitialized Model List {self.get_model_names()}"
|
||||
)
|
||||
self.model_names = [m["model_name"] for m in model_list]
|
||||
|
||||
def _add_deployment(self, deployment: Deployment) -> Deployment:
|
||||
|
@ -4630,6 +4632,7 @@ class Router:
|
|||
if hasattr(self, "model_list"):
|
||||
returned_models: List[DeploymentTypedDict] = []
|
||||
|
||||
if hasattr(self, "model_group_alias"):
|
||||
for model_alias, model_value in self.model_group_alias.items():
|
||||
|
||||
if isinstance(model_value, str):
|
||||
|
@ -5030,7 +5033,7 @@ class Router:
|
|||
# return the first deployment where the `model` matches the specificed deployment name
|
||||
return deployment_model, deployment
|
||||
raise ValueError(
|
||||
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
|
||||
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.get_model_names()}"
|
||||
)
|
||||
elif model in self.get_model_ids():
|
||||
deployment = self.get_model_info(id=model)
|
||||
|
|
|
@ -2045,3 +2045,57 @@ async def test_proxy_logging_setup():
|
|||
|
||||
pl_obj = ProxyLogging(user_api_key_cache=DualCache())
|
||||
assert pl_obj.internal_usage_cache.always_read_redis is True
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="local test. Requires sentinel setup.")
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_sentinel_caching():
|
||||
"""
|
||||
Init redis client
|
||||
- write to client
|
||||
- read from client
|
||||
"""
|
||||
litellm.set_verbose = False
|
||||
|
||||
random_number = random.randint(
|
||||
1, 100000
|
||||
) # add a random number to ensure it's always adding / reading from cache
|
||||
messages = [
|
||||
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
|
||||
]
|
||||
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
# host=os.environ["REDIS_HOST"],
|
||||
# port=os.environ["REDIS_PORT"],
|
||||
# password=os.environ["REDIS_PASSWORD"],
|
||||
service_name="mymaster",
|
||||
sentinel_nodes=[("localhost", 26379)],
|
||||
)
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
cache_key = litellm.cache.get_cache_key(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
)
|
||||
print(f"cache_key: {cache_key}")
|
||||
litellm.cache.add_cache(result=response1, cache_key=cache_key)
|
||||
print(f"cache key pre async get: {cache_key}")
|
||||
stored_val = litellm.cache.get_cache(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
print(f"stored_val: {stored_val}")
|
||||
assert stored_val["id"] == response1.id
|
||||
|
||||
stored_val_2 = await litellm.cache.async_get_cache(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
print(f"stored_val: {stored_val}")
|
||||
assert stored_val_2["id"] == response1.id
|
||||
|
|
|
@ -1255,3 +1255,16 @@ def test_completion_cost_databricks_embedding(model):
|
|||
|
||||
print(resp)
|
||||
cost = completion_cost(completion_response=resp)
|
||||
|
||||
|
||||
def test_completion_cost_fireworks_ai():
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
resp = litellm.completion(
|
||||
model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
|
||||
) # works fine
|
||||
|
||||
print(resp)
|
||||
cost = completion_cost(completion_response=resp)
|
||||
|
|
|
@ -432,3 +432,7 @@ def test_vertex_only_image_user_message():
|
|||
), "Invalid gemini input. Got={}, Expected={}".format(
|
||||
content, expected_response[idx]
|
||||
)
|
||||
|
||||
|
||||
def test_convert_url():
|
||||
convert_url_to_base64("https://picsum.photos/id/237/200/300")
|
||||
|
|
|
@ -6830,7 +6830,10 @@ def exception_type(
|
|||
llm_provider=custom_llm_provider,
|
||||
model=model,
|
||||
)
|
||||
elif original_exception.status_code == 401:
|
||||
elif (
|
||||
original_exception.status_code == 401
|
||||
or original_exception.status_code == 403
|
||||
):
|
||||
exception_mapping_worked = True
|
||||
raise AuthenticationError(
|
||||
message=f"{custom_llm_provider}Exception - {original_exception.message}",
|
||||
|
|
|
@ -5452,6 +5452,26 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"source": "https://fireworks.ai/pricing"
|
||||
},
|
||||
"fireworks-ai-up-to-16b": {
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0000002,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-16.1b-to-80b": {
|
||||
"input_cost_per_token": 0.0000009,
|
||||
"output_cost_per_token": 0.0000009,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-moe-up-to-56b": {
|
||||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000005,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"fireworks-ai-56b-to-176b": {
|
||||
"input_cost_per_token": 0.0000012,
|
||||
"output_cost_per_token": 0.0000012,
|
||||
"litellm_provider": "fireworks_ai"
|
||||
},
|
||||
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 16384,
|
||||
|
|
|
@ -7,10 +7,17 @@ from typing import Any, Dict, List
|
|||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError, InternalServerError
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
try:
|
||||
import databricks.sdk
|
||||
|
||||
databricks_sdk_installed = True
|
||||
except ImportError:
|
||||
databricks_sdk_installed = False
|
||||
|
||||
|
||||
def mock_chat_response() -> Dict[str, Any]:
|
||||
return {
|
||||
|
@ -33,8 +40,8 @@ def mock_chat_response() -> Dict[str, Any]:
|
|||
"usage": {
|
||||
"prompt_tokens": 230,
|
||||
"completion_tokens": 38,
|
||||
"total_tokens": 268,
|
||||
"completion_tokens_details": None,
|
||||
"total_tokens": 268,
|
||||
},
|
||||
"system_fingerprint": None,
|
||||
}
|
||||
|
@ -195,7 +202,14 @@ def mock_embedding_response() -> Dict[str, Any]:
|
|||
|
||||
|
||||
@pytest.mark.parametrize("set_base", [True, False])
|
||||
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
|
||||
def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
|
||||
monkeypatch, set_base
|
||||
):
|
||||
# Simulate that the databricks SDK is not installed
|
||||
monkeypatch.setitem(sys.modules, "databricks.sdk", None)
|
||||
|
||||
err_msg = "the Databricks base URL and API key are not set"
|
||||
|
||||
if set_base:
|
||||
monkeypatch.setenv(
|
||||
"DATABRICKS_API_BASE",
|
||||
|
@ -204,11 +218,11 @@ def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
|
|||
monkeypatch.delenv(
|
||||
"DATABRICKS_API_KEY",
|
||||
)
|
||||
err_msg = "A call is being made to LLM Provider but no key is set"
|
||||
else:
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
|
||||
monkeypatch.delenv("DATABRICKS_API_BASE")
|
||||
err_msg = "A call is being made to LLM Provider but no api base is set"
|
||||
monkeypatch.delenv(
|
||||
"DATABRICKS_API_BASE",
|
||||
)
|
||||
|
||||
with pytest.raises(BadRequestError) as exc:
|
||||
litellm.completion(
|
||||
|
@ -422,6 +436,67 @@ def test_completions_streaming_with_async_http_handler(monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
|
||||
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from databricks.sdk.config import Config
|
||||
|
||||
sync_handler = HTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_chat_response()
|
||||
|
||||
expected_response_json = {
|
||||
**mock_chat_response(),
|
||||
**{
|
||||
"model": "databricks/dbrx-instruct-071224",
|
||||
},
|
||||
}
|
||||
|
||||
base_url = "https://my.workspace.cloud.databricks.com"
|
||||
api_key = "dapimykey"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
messages = [{"role": "user", "content": "How are you?"}]
|
||||
|
||||
mock_workspace_client: WorkspaceClient = MagicMock()
|
||||
mock_config: Config = MagicMock()
|
||||
# Simulate the behavior of the config property and its methods
|
||||
mock_config.authenticate.side_effect = lambda: headers
|
||||
mock_config.host = base_url # Assign directly as if it's a property
|
||||
mock_workspace_client.config = mock_config
|
||||
|
||||
with patch(
|
||||
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
|
||||
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
|
||||
response = litellm.completion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages=messages,
|
||||
client=sync_handler,
|
||||
temperature=0.5,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
assert response.to_dict() == expected_response_json
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/serving-endpoints/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "dbrx-instruct-071224",
|
||||
"messages": messages,
|
||||
"temperature": 0.5,
|
||||
"extraparam": "testpassingextraparam",
|
||||
"stream": False,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_embeddings_with_sync_http_handler(monkeypatch):
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
|
@ -500,3 +575,59 @@ def test_embeddings_with_async_http_handler(monkeypatch):
|
|||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not databricks_sdk_installed, reason="Databricks SDK not installed")
|
||||
def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
|
||||
from databricks.sdk import WorkspaceClient
|
||||
from databricks.sdk.config import Config
|
||||
|
||||
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
|
||||
api_key = "dapimykey"
|
||||
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
|
||||
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
|
||||
|
||||
sync_handler = HTTPHandler()
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_embedding_response()
|
||||
|
||||
base_url = "https://my.workspace.cloud.databricks.com"
|
||||
api_key = "dapimykey"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
inputs = ["Hello", "World"]
|
||||
|
||||
mock_workspace_client: WorkspaceClient = MagicMock()
|
||||
mock_config: Config = MagicMock()
|
||||
# Simulate the behavior of the config property and its methods
|
||||
mock_config.authenticate.side_effect = lambda: headers
|
||||
mock_config.host = base_url # Assign directly as if it's a property
|
||||
mock_workspace_client.config = mock_config
|
||||
|
||||
with patch(
|
||||
"databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
|
||||
), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
|
||||
response = litellm.embedding(
|
||||
model="databricks/bge-large-en-v1.5",
|
||||
input=inputs,
|
||||
client=sync_handler,
|
||||
extraparam="testpassingextraparam",
|
||||
)
|
||||
assert response.to_dict() == mock_embedding_response()
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
f"{base_url}/serving-endpoints/embeddings",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": "bge-large-en-v1.5",
|
||||
"input": inputs,
|
||||
"extraparam": "testpassingextraparam",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
|
|
@ -7,7 +7,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.llms.fireworks_ai import FireworksAIConfig
|
||||
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import FireworksAIConfig
|
||||
|
||||
fireworks = FireworksAIConfig()
|
||||
|
||||
|
|
|
@ -149,7 +149,9 @@ def test_all_model_configs():
|
|||
{"max_completion_tokens": 10}, {}, "llama3"
|
||||
) == {"max_tokens": 10}
|
||||
|
||||
from litellm.llms.fireworks_ai import FireworksAIConfig
|
||||
from litellm.llms.fireworks_ai.chat.fireworks_ai_transformation import (
|
||||
FireworksAIConfig,
|
||||
)
|
||||
|
||||
assert "max_completion_tokens" in FireworksAIConfig().get_supported_openai_params()
|
||||
assert FireworksAIConfig().map_openai_params(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue