LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842) (#5858)

* LiteLLM Minor Fixes & Improvements (09/23/2024)  (#5842)

* feat(auth_utils.py): enable admin to allow client-side credentials to be passed

Makes it easier for devs to experiment with fine-tuned Fireworks AI models.
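
A minimal sketch of how the new request-body check behaves (mirroring the unit test added in this commit); the endpoint URL is a hypothetical placeholder:

```python
from litellm.proxy.auth.auth_utils import is_request_body_safe

# With the global `allow_client_side_credentials` setting enabled, a
# client-supplied `api_base` passes the safety check instead of being rejected.
is_request_body_safe(
    request_body={"api_base": "https://my-finetuned-endpoint.example.com"},  # hypothetical
    general_settings={"allow_client_side_credentials": True},
    llm_router=None,
    model="gpt-3.5-turbo",
)  # returns True; with the flag off (and no model-level allowance) this raises ValueError
```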

* feat(router.py): allow setting configurable_clientside_auth_params for a model

Closes https://github.com/BerriAI/litellm/issues/5843
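
A minimal sketch of the model-level setting, based on the shape used by the proxy-config example and the test added in this commit:

```python
import os

from litellm import Router

# The wildcard fireworks_ai deployment lets callers override `api_base`
# per-request, without opening up client-side credentials globally.
llm_router = Router(
    model_list=[
        {
            "model_name": "fireworks_ai/*",
            "litellm_params": {
                "model": "fireworks_ai/*",
                "api_key": os.getenv("FIREWORKS_API_KEY"),
                "configurable_clientside_auth_params": ["api_base"],
            },
        }
    ]
)
```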

* build(model_prices_and_context_window.json): fix anthropic claude-3-5-sonnet max output token limit

Fixes https://github.com/BerriAI/litellm/issues/5850

* fix(azure_ai/): support content list for azure ai

Fixes https://github.com/BerriAI/litellm/issues/4237
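
A short sketch of the content-list flattening this relies on (see `convert_content_list_to_str` in the diff below); a text-only content list is collapsed into a plain string before the request is sent:

```python
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str

message = {
    "role": "user",
    "content": [{"type": "text", "text": "What is the meaning of life?"}],
}

# The message comes back with "content" flattened to the string
# "What is the meaning of life?".
flattened = convert_content_list_to_str(message=message)
```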

* fix(litellm_logging.py): always set saved_cache_cost

Set to 0 by default

* fix(fireworks_ai/cost_calculator.py): add fireworks ai default pricing

Handles calling 405B+ models that don't have an explicit pricing entry.
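
A quick sketch of the fallback (expected values come from the parametrized test added in this commit): unmapped Fireworks AI models now resolve to the zero-cost "fireworks-ai-default" pricing entry instead of failing cost calculation:

```python
from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing

print(get_base_model_for_pricing(model_name="fireworks_ai/llama-v3p1-405b-instruct"))
# -> "fireworks-ai-default"  (no size bucket matches 405B)
print(get_base_model_for_pricing(model_name="fireworks_ai/mixtral-8x7b-instruct"))
# -> "fireworks-ai-moe-up-to-56b"
```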

* fix(slack_alerting.py): fix error alerting for failed spend tracking

Fixes a regression in Slack alerting's error monitoring.
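
Related cleanup in the same change: the default alert types are now derived from the `AlertType` Literal itself rather than a hand-maintained list (which had drifted and was missing "failed_tracking_spend"). A trimmed-down sketch of the pattern:

```python
from typing import Literal, get_args

# Trimmed for illustration; the real AlertType Literal is defined elsewhere in LiteLLM.
AlertType = Literal["llm_exceptions", "budget_alerts", "failed_tracking_spend"]

default_alert_types = list(get_args(AlertType))  # every member of the Literal
```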

* fix(vertex_and_google_ai_studio_gemini.py): handle gemini no candidates in streaming chunk error
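
The fix follows the defensive pattern sketched below (a hypothetical chunk is used for illustration): a streaming chunk may arrive without a "candidates" key, so it is looked up with `.get()` and only indexed when non-empty:

```python
# Hypothetical Gemini streaming chunk with no "candidates" present.
processed_chunk = {"usageMetadata": {"totalTokenCount": 10}}

_candidates = processed_chunk.get("candidates")
gemini_chunk = None
if _candidates and len(_candidates) > 0:
    gemini_chunk = _candidates[0]

if gemini_chunk and "content" in gemini_chunk:
    ...  # extract text / tool calls as before; otherwise an empty chunk is returned
```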

* docs(bedrock.md): add llama3-1 models
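
Usage sketch for the newly documented Bedrock Llama 3.1 models, following the pattern in the table added to bedrock.md (the region/credential env-var setup shown is the usual Bedrock configuration, not part of the doc change):

```python
import os

from litellm import completion

os.environ["AWS_ACCESS_KEY_ID"] = ""       # your AWS credentials
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""         # e.g. us-west-2

response = completion(
    model="bedrock/meta.llama3-1-405b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```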

* test: fix tests

* fix(azure_ai/chat): fix transformation for azure ai calls
Krish Dholakia 2024-09-24 15:01:31 -07:00 committed by GitHub
parent 4df9aca45e
commit d37c8b5c6b
25 changed files with 611 additions and 294 deletions


@@ -987,6 +987,9 @@ Here's an example of using a bedrock model with LiteLLM. For a complete list, re
 | Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |


@@ -963,8 +963,8 @@ from .llms.OpenAI.openai import (
     MistralEmbeddingConfig,
     DeepInfraConfig,
     GroqConfig,
-    AzureAIStudioConfig,
 )
+from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.OpenAI.chat.o1_transformation import (
     OpenAIO1Config,


@@ -10,7 +10,7 @@ import traceback
 from datetime import datetime as dt
 from datetime import timedelta, timezone
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union, get_args
 
 import aiohttp
 import dotenv
@@ -57,20 +57,7 @@ class SlackAlerting(CustomBatchLogger):
             float
         ] = None,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
-        alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-            "failed_tracking_spend",
-        ],
+        alert_types: List[AlertType] = list(get_args(AlertType)),
         alert_to_webhook_url: Optional[
             Dict[AlertType, Union[List[str], str]]
         ] = None,  # if user wants to separate alerts to diff channels
@@ -613,7 +600,7 @@ class SlackAlerting(CustomBatchLogger):
            await self.send_alert(
                message=message,
                level="High",
-               alert_type="budget_alerts",
+               alert_type="failed_tracking_spend",
                alerting_metadata={},
            )
            await _cache.async_set_cache(


@@ -2498,14 +2498,17 @@ def get_standard_logging_object_payload(
     else:
         cache_key = None
-    saved_cache_cost: Optional[float] = None
+    saved_cache_cost: float = 0.0
     if cache_hit is True:
         id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id
-        saved_cache_cost = logging_obj._response_cost_calculator(
-            result=init_response_obj, cache_hit=False  # type: ignore
-        )
+        saved_cache_cost = (
+            logging_obj._response_cost_calculator(
+                result=init_response_obj, cache_hit=False  # type: ignore
+            )
+            or 0.0
+        )
 
     ## Get model cost information ##
     base_model = _get_base_model_from_metadata(model_call_details=kwargs)


@@ -103,25 +103,6 @@ class MistralEmbeddingConfig:
         return optional_params
 
 
-class AzureAIStudioConfig:
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return it's required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Azure AI Studio API Key.",
-                field_value="zEJ...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Azure AI Studio API Base.",
-                field_value="https://Mistral-serverless.",
-            ),
-        ]
-
-
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api


@@ -0,0 +1,59 @@
+from typing import Any, Callable, List, Optional, Union
+
+from httpx._config import Timeout
+
+from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
+from litellm.llms.OpenAI.openai import OpenAIChatCompletion
+from litellm.types.utils import ModelResponse
+from litellm.utils import CustomStreamWrapper
+
+from .transformation import AzureAIStudioConfig
+
+
+class AzureAIChatCompletion(OpenAIChatCompletion):
+    def completion(
+        self,
+        model_response: ModelResponse,
+        timeout: Union[float, Timeout],
+        optional_params: dict,
+        logging_obj: Any,
+        model: Optional[str] = None,
+        messages: Optional[list] = None,
+        print_verbose: Optional[Callable[..., Any]] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        acompletion: bool = False,
+        litellm_params=None,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        custom_prompt_dict: dict = {},
+        client=None,
+        organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
+    ):
+
+        transformed_messages = AzureAIStudioConfig()._transform_messages(
+            messages=messages  # type: ignore
+        )
+
+        return super().completion(
+            model_response,
+            timeout,
+            optional_params,
+            logging_obj,
+            model,
+            transformed_messages,
+            print_verbose,
+            api_key,
+            api_base,
+            acompletion,
+            litellm_params,
+            logger_fn,
+            headers,
+            custom_prompt_dict,
+            client,
+            organization,
+            custom_llm_provider,
+            drop_params,
+        )


@@ -0,0 +1,31 @@
+from typing import List
+
+from litellm.llms.OpenAI.openai import OpenAIConfig
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ProviderField
+
+
+class AzureAIStudioConfig(OpenAIConfig):
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return it's required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Azure AI Studio API Key.",
+                field_value="zEJ...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Azure AI Studio API Base.",
+                field_value="https://Mistral-serverless.",
+            ),
+        ]
+
+    def _transform_messages(self, messages: List[AllMessageValues]) -> List:
+        for message in messages:
+            message = convert_content_list_to_str(message=message)
+
+        return messages


@@ -10,7 +10,7 @@ from litellm.utils import get_model_info
 
 # Extract the number of billion parameters from the model name
 # only used for together_computer LLMs
-def get_model_params_and_category(model_name: str) -> str:
+def get_base_model_for_pricing(model_name: str) -> str:
     """
     Helper function for calculating together ai pricing.
 
@@ -43,7 +43,7 @@ def get_model_params_and_category(model_name: str) -> str:
         return "fireworks-ai-16b-80b"
 
     # If no matches, return the original model_name
-    return model_name
+    return "fireworks-ai-default"
 
 
 def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@@ -57,10 +57,16 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
     Returns:
         Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
     """
-    base_model = get_model_params_and_category(model_name=model)
+    ## check if model mapped, else use default pricing
+    try:
+        model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
+    except Exception:
+        base_model = get_base_model_for_pricing(model_name=model)
 
-    ## GET MODEL INFO
-    model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
+        ## GET MODEL INFO
+        model_info = get_model_info(
+            model=base_model, custom_llm_provider="fireworks_ai"
+        )
 
     ## CALCULATE INPUT COST


@@ -0,0 +1,32 @@
+"""
+Common utility functions used for translating messages across providers
+"""
+
+from typing import List
+
+from litellm.types.llms.openai import AllMessageValues
+
+
+def convert_content_list_to_str(message: AllMessageValues) -> AllMessageValues:
+    """
+    - handles scenario where content is list and not string
+    - content list is just text, and no images
+    - if image passed in, then just return as is (user-intended)
+
+    Motivation: mistral api + azure ai don't support content as a list
+    """
+    texts = ""
+
+    message_content = message.get("content")
+    if message_content:
+        if message_content is not None and isinstance(message_content, list):
+            for c in message_content:
+                text_content = c.get("text")
+                if text_content:
+                    texts += text_content
+        elif message_content is not None and isinstance(message_content, str):
+            texts = message_content
+
+    if texts:
+        message["content"] = texts
+
+    return message


@@ -49,6 +49,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
+    Candidates,
     ContentType,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -187,7 +188,11 @@ class VertexAIConfig:
                 optional_params["stop_sequences"] = value
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":
+            if (
+                param == "response_format"
+                and isinstance(value, dict)
+                and value["type"] == "json_object"
+            ):
                 optional_params["response_mime_type"] = "application/json"
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
@@ -900,14 +905,14 @@ class VertexLLM(VertexBase):
 
             return model_response
 
-        if len(completion_response["candidates"]) > 0:
+        _candidates = completion_response.get("candidates")
+        if _candidates and len(_candidates) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
             )
             if (
-                "finishReason" in completion_response["candidates"][0]
-                and completion_response["candidates"][0]["finishReason"]
-                in content_policy_violations.keys()
+                "finishReason" in _candidates[0]
+                and _candidates[0]["finishReason"] in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
                 model_response.choices[0].finish_reason = "content_filter"
@@ -956,12 +961,13 @@
         content_str = ""
         tools: List[ChatCompletionToolCallChunk] = []
         functions: Optional[ChatCompletionToolCallFunctionChunk] = None
-        for idx, candidate in enumerate(completion_response["candidates"]):
-            if "content" not in candidate:
-                continue
+        if _candidates:
+            for idx, candidate in enumerate(_candidates):
+                if "content" not in candidate:
+                    continue
 
-            if "groundingMetadata" in candidate:
-                grounding_metadata.append(candidate["groundingMetadata"])
+                if "groundingMetadata" in candidate:
+                    grounding_metadata.append(candidate["groundingMetadata"])  # type: ignore
 
-            if "safetyRatings" in candidate:
-                safety_ratings.append(candidate["safetyRatings"])
+                if "safetyRatings" in candidate:
+                    safety_ratings.append(candidate["safetyRatings"])
@@ -973,7 +979,9 @@
                 if "functionCall" in candidate["content"]["parts"][0]:
                     _function_chunk = ChatCompletionToolCallFunctionChunk(
-                        name=candidate["content"]["parts"][0]["functionCall"]["name"],
+                        name=candidate["content"]["parts"][0]["functionCall"][
+                            "name"
+                        ],
                         arguments=json.dumps(
                             candidate["content"]["parts"][0]["functionCall"]["args"]
                         ),
@@ -1433,10 +1441,12 @@ class ModelResponseIterator:
             is_finished = False
             finish_reason = ""
             usage: Optional[ChatCompletionUsageBlock] = None
+            _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
+            gemini_chunk: Optional[Candidates] = None
+            if _candidates and len(_candidates) > 0:
+                gemini_chunk = _candidates[0]
 
-            gemini_chunk = processed_chunk["candidates"][0]
-
-            if "content" in gemini_chunk:
+            if gemini_chunk and "content" in gemini_chunk:
                 if "text" in gemini_chunk["content"]["parts"][0]:
                     text = gemini_chunk["content"]["parts"][0]["text"]
                 elif "functionCall" in gemini_chunk["content"]["parts"][0]:
@@ -1455,7 +1465,7 @@
                         index=0,
                     )
 
-            if "finishReason" in gemini_chunk:
+            if gemini_chunk and "finishReason" in gemini_chunk:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
@@ -1533,6 +1543,7 @@
             )
 
     def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
-        chunk = chunk.replace("data:", "")
-        if len(chunk) > 0:
-            """
+        try:
+            chunk = chunk.replace("data:", "")
+            if len(chunk) > 0:
+                """
@@ -1544,7 +1555,7 @@
                     return self.handle_valid_json_chunk(chunk=chunk)
                 elif self.chunk_type == "accumulated_json":
                     return self.handle_accumulated_json_chunk(chunk=chunk)
-        else:
+
             return GenericStreamingChunk(
                 text="",
                 is_finished=False,
@@ -1553,6 +1564,8 @@
                 index=0,
                 tool_use=None,
             )
+        except Exception:
+            raise
 
     def __next__(self):
         try:


@@ -83,6 +83,7 @@ from .llms import (
 from .llms.AI21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.anthropic.completion import AnthropicTextCompletion
+from .llms.azure_ai.chat.handler import AzureAIChatCompletion
 from .llms.azure_text import AzureTextCompletion
 from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
 from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -166,6 +167,7 @@ openai_text_completions = OpenAITextCompletion()
 openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
+azure_ai_chat_completions = AzureAIChatCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -1177,7 +1179,7 @@ def completion(
         headers = headers or litellm.headers
 
         ## LOAD CONFIG - if set
-        config = litellm.OpenAIConfig.get_config()
+        config = litellm.AzureAIStudioConfig.get_config()
         for k, v in config.items():
             if (
                 k not in optional_params
@@ -1190,7 +1192,7 @@ def completion(
         ## COMPLETION CALL
         try:
-            response = openai_chat_completions.completion(
+            response = azure_ai_chat_completions.completion(
                 model=model,
                 messages=messages,
                 headers=headers,


@@ -3862,9 +3862,9 @@
         "supports_vision": true
     },
     "anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3906,9 +3906,9 @@
         "supports_vision": true
     },
     "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3939,9 +3939,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-sonnet-20240229-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3950,9 +3950,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -5593,6 +5593,11 @@
         "output_cost_per_token": 0.0000012,
         "litellm_provider": "fireworks_ai"
     },
+    "fireworks-ai-default": {
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "fireworks_ai"
+    },
     "fireworks-ai-embedding-up-to-150m": {
         "input_cost_per_token": 0.000000008,
         "output_cost_per_token": 0.000000,


@@ -31,6 +31,13 @@ model_list:
   - model_name: "anthropic/*"
     litellm_params:
       model: "anthropic/*"
+  - model_name: "openai/*"
+    litellm_params:
+      model: "openai/*"
+  - model_name: "fireworks_ai/*"
+    litellm_params:
+      model: "fireworks_ai/*"
+      configurable_clientside_auth_params: ["api_base"]
 
 litellm_settings:


@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 
 from fastapi import HTTPException, Request, status
 
+from litellm import Router, provider_list
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *
 
@@ -72,7 +73,41 @@ def check_complete_credentials(request_body: dict) -> bool:
     return False
 
 
-def is_request_body_safe(request_body: dict) -> bool:
+def _allow_model_level_clientside_configurable_parameters(
+    model: str, param: str, llm_router: Optional[Router]
+) -> bool:
+    """
+    Check if model is allowed to use configurable client-side params
+    - get matching model
+    - check if 'clientside_configurable_parameters' is set for model
+    -
+    """
+    if llm_router is None:
+        return False
+    # check if model is set
+    model_info = llm_router.get_model_group_info(model_group=model)
+    if model_info is None:
+        # check if wildcard model is set
+        if model.split("/", 1)[0] in provider_list:
+            model_info = llm_router.get_model_group_info(
+                model_group=model.split("/", 1)[0]
+            )
+
+    if model_info is None:
+        return False
+
+    if model_info is None or model_info.configurable_clientside_auth_params is None:
+        return False
+
+    if param in model_info.configurable_clientside_auth_params:
+        return True
+
+    return False
+
+
+def is_request_body_safe(
+    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
     """
     Check if the request body is safe.
 
@@ -88,7 +123,20 @@
                 request_body=request_body
             )
         ):
-            raise ValueError(f"BadRequest: {param} is not allowed in request body")
+            if general_settings.get("allow_client_side_credentials") is True:
+                return True
+            elif (
+                _allow_model_level_clientside_configurable_parameters(
+                    model=model, param=param, llm_router=llm_router
+                )
+                is True
+            ):
+                return True
+            raise ValueError(
+                f"Rejected Request: {param} is not allowed in request body. "
+                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+            )
 
     return True
 
@@ -110,13 +158,20 @@
     Raises:
     - HTTPException if request fails initial auth checks
     """
-    from litellm.proxy.proxy_server import general_settings, premium_user
+    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
 
     # Check 1. request size
     await check_if_request_size_is_safe(request=request)
 
     # Check 2. Request body is safe
-    is_request_body_safe(request_body=request_data)
+    is_request_body_safe(
+        request_body=request_data,
+        general_settings=general_settings,
+        llm_router=llm_router,
+        model=request_data.get(
+            "model", ""
+        ),  # [TODO] use model passed in url as well (azure openai routes)
+    )
 
     # Check 3. Check if IP address is allowed
     is_valid_ip, passed_in_ip = _check_valid_ip(


@@ -66,7 +66,7 @@ async def route_request(
     """
     router_model_names = llm_router.model_names if llm_router is not None else []
 
-    if "api_key" in data:
+    if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)
 
     elif "user_config" in data:


@@ -4,6 +4,8 @@ import secrets
 import traceback
 from typing import Optional
 
+from pydantic import BaseModel
+
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
@@ -105,6 +107,8 @@ def get_logging_payload(
     additional_usage_values = {}
     for k, v in usage.items():
         if k not in special_usage_fields:
+            if isinstance(v, BaseModel):
+                v = v.model_dump()
             additional_usage_values.update({k: v})
     clean_metadata["additional_usage_values"] = additional_usage_values


@@ -14,7 +14,17 @@ from datetime import datetime, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from functools import wraps
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    overload,
+)
 
 import backoff
 import httpx
@@ -222,19 +232,7 @@ class ProxyLogging:
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-        ]
+        self.alert_types: List[AlertType] = list(get_args(AlertType))
         self.alert_to_webhook_url: Optional[dict] = None
         self.slack_alerting_instance: SlackAlerting = SlackAlerting(
             alerting_threshold=self.alerting_threshold,


@@ -4335,11 +4335,28 @@ class Router:
         total_tpm: Optional[int] = None
         total_rpm: Optional[int] = None
+        configurable_clientside_auth_params: Optional[List[str]] = None
 
         for model in self.model_list:
-            if "model_name" in model and model["model_name"] == model_group:
+            is_match = False
+            if (
+                "model_name" in model and model["model_name"] == model_group
+            ):  # exact match
+                is_match = True
+            elif (
+                "model_name" in model
+                and model_group in self.provider_default_deployments
+            ):  # wildcard model
+                is_match = True
+
+            if not is_match:
+                continue
             # model in model group found #
             litellm_params = LiteLLM_Params(**model["litellm_params"])
+            # get configurable clientside auth params
+            configurable_clientside_auth_params = (
+                litellm_params.configurable_clientside_auth_params
+            )
             # get model tpm
             _deployment_tpm: Optional[int] = None
             if _deployment_tpm is None:
@@ -4425,9 +4442,7 @@
                         > model_group_info.max_input_tokens
                     )
                 ):
-                    model_group_info.max_input_tokens = model_info[
-                        "max_input_tokens"
-                    ]
+                    model_group_info.max_input_tokens = model_info["max_input_tokens"]
                 if (
                     model_info.get("max_output_tokens", None) is not None
                     and model_info["max_output_tokens"] is not None
@@ -4437,9 +4452,7 @@
                         > model_group_info.max_output_tokens
                     )
                 ):
-                    model_group_info.max_output_tokens = model_info[
-                        "max_output_tokens"
-                    ]
+                    model_group_info.max_output_tokens = model_info["max_output_tokens"]
                 if model_info.get("input_cost_per_token", None) is not None and (
                     model_group_info.input_cost_per_token is None
                     or model_info["input_cost_per_token"]
@@ -4480,13 +4493,20 @@
                         "supported_openai_params"
                     ]
 
-        ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
-        if total_tpm is not None and model_group_info is not None:
-            model_group_info.tpm = total_tpm
-
-        if total_rpm is not None and model_group_info is not None:
-            model_group_info.rpm = total_rpm
+        if model_group_info is not None:
+            ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
+            if total_tpm is not None:
+                model_group_info.tpm = total_tpm
+
+            if total_rpm is not None:
+                model_group_info.rpm = total_rpm
+
+            ## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP
+            if configurable_clientside_auth_params is not None:
+                model_group_info.configurable_clientside_auth_params = (
+                    configurable_clientside_auth_params
+                )
 
         return model_group_info
 
     def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:


@@ -141,9 +141,16 @@ def test_completion_azure_ai_command_r():
         os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
         os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
 
-        response: litellm.ModelResponse = completion(
+        response = completion(
             model="azure_ai/command-r-plus",
-            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the meaning of life?"}
+                    ],
+                }
+            ],
         )  # type: ignore
 
         assert "azure_ai" in response.model


@@ -1257,14 +1257,31 @@ def test_completion_cost_databricks_embedding(model):
     cost = completion_cost(completion_response=resp)
 
 
-def test_completion_cost_fireworks_ai():
+from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing
+
+
+@pytest.mark.parametrize(
+    "model, base_model",
+    [
+        ("fireworks_ai/llama-v3p1-405b-instruct", "fireworks-ai-default"),
+        ("fireworks_ai/mixtral-8x7b-instruct", "fireworks-ai-moe-up-to-56b"),
+    ],
+)
+def test_get_model_params_fireworks_ai(model, base_model):
+    pricing_model = get_base_model_for_pricing(model_name=model)
+    assert base_model == pricing_model
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["fireworks_ai/llama-v3p1-405b-instruct", "fireworks_ai/mixtral-8x7b-instruct"],
+)
+def test_completion_cost_fireworks_ai(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")
 
     messages = [{"role": "user", "content": "Hey, how's it going?"}]
-    resp = litellm.completion(
-        model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
-    )  # works fine
+    resp = litellm.completion(model=model, messages=messages)  # works fine
 
     print(resp)
     cost = completion_cost(completion_response=resp)


@@ -12,6 +12,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import is_request_body_safe
 from litellm.proxy.litellm_pre_call_utils import (
     _get_dynamic_logging_metadata,
     add_litellm_data_to_request,
@@ -291,3 +292,78 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
     for var in callbacks.callback_vars.values():
         assert "os.environ" not in var
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_global_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={
+                "allow_client_side_credentials": allow_client_side_credentials
+            },
+            llm_router=llm_router,
+            model="gpt-3.5-turbo",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_model_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "fireworks_ai/*",
+                "litellm_params": {
+                    "model": "fireworks_ai/*",
+                    "api_key": os.getenv("FIREWORKS_API_KEY"),
+                    "configurable_clientside_auth_params": (
+                        ["api_base"] if allow_client_side_credentials else []
+                    ),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={},
+            llm_router=llm_router,
+            model="fireworks_ai/my-new-model",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised


@@ -283,7 +283,7 @@ class PromptFeedback(TypedDict):
 
 class GenerateContentResponseBody(TypedDict, total=False):
-    candidates: Required[List[Candidates]]
+    candidates: List[Candidates]
     promptFeedback: PromptFeedback
     usageMetadata: Required[UsageMetadata]


@@ -139,6 +139,7 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
+    configurable_clientside_auth_params: Optional[List[str]] = None
     ## UNIFIED PROJECT/REGION ##
     region_name: Optional[str] = None
     ## VERTEX AI ##
@@ -310,6 +311,9 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
     organization: Optional[Union[List, str]]  # for openai orgs
+    configurable_clientside_auth_params: Optional[
+        List[str]
+    ]  # for allowing api base switching on finetuned models
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
@@ -487,6 +491,7 @@ class ModelGroupInfo(BaseModel):
     supports_vision: bool = Field(default=False)
     supports_function_calling: bool = Field(default=False)
     supported_openai_params: Optional[List[str]] = Field(default=[])
+    configurable_clientside_auth_params: Optional[List[str]] = None
 
 
 class AssistantsTypedDict(TypedDict):


@@ -1196,6 +1196,7 @@ all_litellm_params = [
     "client_id",
     "client_secret",
     "user_continue_message",
+    "configurable_clientside_auth_params",
 ]
 
@@ -1323,7 +1324,7 @@ class StandardLoggingPayload(TypedDict):
     metadata: StandardLoggingMetadata
     cache_hit: Optional[bool]
     cache_key: Optional[str]
-    saved_cache_cost: Optional[float]
+    saved_cache_cost: float
     request_tags: list
     end_user: Optional[str]
     requester_ip_address: Optional[str]


@@ -3862,9 +3862,9 @@
         "supports_vision": true
     },
     "anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3906,9 +3906,9 @@
         "supports_vision": true
     },
    "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3939,9 +3939,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-sonnet-20240229-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3950,9 +3950,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -5593,6 +5593,11 @@
         "output_cost_per_token": 0.0000012,
         "litellm_provider": "fireworks_ai"
     },
+    "fireworks-ai-default": {
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "fireworks_ai"
+    },
     "fireworks-ai-embedding-up-to-150m": {
         "input_cost_per_token": 0.000000008,
         "output_cost_per_token": 0.000000,