LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842) (#5858)

* LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842)

* feat(auth_utils.py): enable admin to allow client-side credentials to be passed

Makes it easier for devs to experiment with fine-tuned Fireworks AI models

* feat(router.py): allow setting configurable_clientside_auth_params for a model

Closes https://github.com/BerriAI/litellm/issues/5843

* build(model_prices_and_context_window.json): fix anthropic claude-3-5-sonnet max output token limit

Fixes https://github.com/BerriAI/litellm/issues/5850

* fix(azure_ai/): support content list for azure ai

Fixes https://github.com/BerriAI/litellm/issues/4237

* fix(litellm_logging.py): always set saved_cache_cost

Set to 0.0 by default, so the logged value is always a float

* fix(fireworks_ai/cost_calculator.py): add fireworks ai default pricing

Handles calling 405B+ parameter models that have no explicit pricing entry

* fix(slack_alerting.py): fix error alerting for failed spend tracking

Fixes a regression in Slack alerting error monitoring

* fix(vertex_and_google_ai_studio_gemini.py): handle Gemini streaming chunks with no candidates

* docs(bedrock.md): add llama3-1 models

* test: fix tests

* fix(azure_ai/chat): fix transformation for azure ai calls
Author: Krish Dholakia
Committed: 2024-09-24 15:01:31 -07:00 (via GitHub)
Parent: 4df9aca45e
Commit: d37c8b5c6b
25 changed files with 611 additions and 294 deletions


@ -987,6 +987,9 @@ Here's an example of using a bedrock model with LiteLLM. For a complete list, re
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
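
For reference, a minimal sketch of calling one of the newly documented Llama 3.1 Bedrock models through `completion()`; the credentials and region below are placeholders, not values from this commit:

import os
from litellm import completion

# Placeholder AWS credentials for the sketch
os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
os.environ["AWS_REGION_NAME"] = "us-west-2"  # assumed region

response = completion(
    model="bedrock/meta.llama3-1-8b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)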


@ -963,8 +963,8 @@ from .llms.OpenAI.openai import (
MistralEmbeddingConfig,
DeepInfraConfig,
GroqConfig,
AzureAIStudioConfig,
)
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.OpenAI.chat.o1_transformation import (
OpenAIO1Config,


@ -10,7 +10,7 @@ import traceback
from datetime import datetime as dt
from datetime import timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union, get_args
import aiohttp
import dotenv
@ -57,20 +57,7 @@ class SlackAlerting(CustomBatchLogger):
float
] = None, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"fallback_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
"failed_tracking_spend",
],
alert_types: List[AlertType] = list(get_args(AlertType)),
alert_to_webhook_url: Optional[
Dict[AlertType, Union[List[str], str]]
] = None, # if user wants to separate alerts to diff channels
@ -613,7 +600,7 @@ class SlackAlerting(CustomBatchLogger):
await self.send_alert(
message=message,
level="High",
alert_type="budget_alerts",
alert_type="failed_tracking_spend",
alerting_metadata={},
)
await _cache.async_set_cache(
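
The refactor above replaces the hand-maintained default list with `list(get_args(AlertType))`. A self-contained sketch of that pattern, using a stand-in Literal rather than litellm's actual AlertType definition:

from typing import Literal, get_args

# Stand-in for the AlertType Literal alias; the real one lives in litellm's type definitions
AlertType = Literal["llm_exceptions", "llm_too_slow", "budget_alerts", "failed_tracking_spend"]

# get_args() enumerates the Literal's members, so the default alert list
# stays in sync with the type instead of being duplicated by hand
default_alert_types = list(get_args(AlertType))
print(default_alert_types)
# ['llm_exceptions', 'llm_too_slow', 'budget_alerts', 'failed_tracking_spend']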


@ -2498,14 +2498,17 @@ def get_standard_logging_object_payload(
else:
cache_key = None
saved_cache_cost: Optional[float] = None
saved_cache_cost: float = 0.0
if cache_hit is True:
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
saved_cache_cost = (
logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False # type: ignore
)
or 0.0
)
## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
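
The change above just coalesces a possible None from the cost calculator so `saved_cache_cost` is always a float. A stand-alone illustration of the pattern, not litellm code:

def saved_cost(calculated):
    # calculated: Optional[float]; None becomes 0.0, keeping the logged field a plain float
    return calculated or 0.0

print(saved_cost(None))    # 0.0
print(saved_cost(0.0042))  # 0.0042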


@ -103,25 +103,6 @@ class MistralEmbeddingConfig:
return optional_params
class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
class DeepInfraConfig:
"""
Reference: https://deepinfra.com/docs/advanced/openai_api


@ -0,0 +1,59 @@
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from .transformation import AzureAIStudioConfig
class AzureAIChatCompletion(OpenAIChatCompletion):
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
transformed_messages = AzureAIStudioConfig()._transform_messages(
messages=messages # type: ignore
)
return super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
transformed_messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
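
A hedged usage sketch of the new Azure AI chat path: a `completion()` call whose message content is a list of text blocks, which `AzureAIStudioConfig()._transform_messages()` flattens to a plain string before the OpenAI-compatible request is made. The API base and key are placeholders:

import os
import litellm

os.environ["AZURE_AI_API_KEY"] = "..."  # placeholder
os.environ["AZURE_AI_API_BASE"] = "https://example-serverless.example.com"  # placeholder

response = litellm.completion(
    model="azure_ai/command-r-plus",
    messages=[
        {
            "role": "user",
            # content as a list of text parts; flattened by the azure_ai transformation
            "content": [{"type": "text", "text": "What is the meaning of life?"}],
        }
    ],
)
print(response.model)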


@ -0,0 +1,31 @@
from typing import List
from litellm.llms.OpenAI.openai import OpenAIConfig
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField
class AzureAIStudioConfig(OpenAIConfig):
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
def _transform_messages(self, messages: List[AllMessageValues]) -> List:
for message in messages:
message = convert_content_list_to_str(message=message)
return messages


@ -10,7 +10,7 @@ from litellm.utils import get_model_info
# Extract the number of billion parameters from the model name
# only used for together_computer LLMs
def get_model_params_and_category(model_name: str) -> str:
def get_base_model_for_pricing(model_name: str) -> str:
"""
Helper function for calculating together ai pricing.
@ -43,7 +43,7 @@ def get_model_params_and_category(model_name: str) -> str:
return "fireworks-ai-16b-80b"
# If no matches, return the original model_name
return model_name
return "fireworks-ai-default"
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@ -57,10 +57,16 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
base_model = get_model_params_and_category(model_name=model)
## check if model mapped, else use default pricing
try:
model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
except Exception:
base_model = get_base_model_for_pricing(model_name=model)
## GET MODEL INFO
model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
model_info = get_model_info(
model=base_model, custom_llm_provider="fireworks_ai"
)
## CALCULATE INPUT COST
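
A short sketch of the fallback this introduces: model names that can't be mapped to a size bucket resolve to the new `fireworks-ai-default` pricing key instead of raising. The import path matches the one used in the tests later in this diff:

from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing

# Recognized size bucket -> existing pricing key
print(get_base_model_for_pricing(model_name="fireworks_ai/mixtral-8x7b-instruct"))
# fireworks-ai-moe-up-to-56b

# Unrecognized (e.g. 405B-class) model -> default pricing key
print(get_base_model_for_pricing(model_name="fireworks_ai/llama-v3p1-405b-instruct"))
# fireworks-ai-default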


@ -0,0 +1,32 @@
"""
Common utility functions used for translating messages across providers
"""
from typing import List
from litellm.types.llms.openai import AllMessageValues
def convert_content_list_to_str(message: AllMessageValues) -> AllMessageValues:
"""
- handles scenario where content is list and not string
- content list is just text, and no images
- if image passed in, then just return as is (user-intended)
Motivation: mistral api + azure ai don't support content as a list
"""
texts = ""
message_content = message.get("content")
if message_content:
if message_content is not None and isinstance(message_content, list):
for c in message_content:
text_content = c.get("text")
if text_content:
texts += text_content
elif message_content is not None and isinstance(message_content, str):
texts = message_content
if texts:
message["content"] = texts
return message
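
A minimal sketch of the new helper's behavior on an OpenAI-style message whose content is a list of text parts (a plain dict is used here for brevity; the function is typed against AllMessageValues):

from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Hello, "},
        {"type": "text", "text": "world"},
    ],
}

flattened = convert_content_list_to_str(message=message)
print(flattened["content"])  # Hello, world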


@ -49,6 +49,7 @@ from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
Candidates,
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
@ -187,7 +188,11 @@ class VertexAIConfig:
optional_params["stop_sequences"] = value
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object":
if (
param == "response_format"
and isinstance(value, dict)
and value["type"] == "json_object"
):
optional_params["response_mime_type"] = "application/json"
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
@ -900,14 +905,14 @@ class VertexLLM(VertexBase):
return model_response
if len(completion_response["candidates"]) > 0:
_candidates = completion_response.get("candidates")
if _candidates and len(_candidates) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
)
if (
"finishReason" in completion_response["candidates"][0]
and completion_response["candidates"][0]["finishReason"]
in content_policy_violations.keys()
"finishReason" in _candidates[0]
and _candidates[0]["finishReason"] in content_policy_violations.keys()
):
## CONTENT POLICY VIOLATION ERROR
model_response.choices[0].finish_reason = "content_filter"
@ -956,12 +961,13 @@ class VertexLLM(VertexBase):
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
for idx, candidate in enumerate(completion_response["candidates"]):
if _candidates:
for idx, candidate in enumerate(_candidates):
if "content" not in candidate:
continue
if "groundingMetadata" in candidate:
grounding_metadata.append(candidate["groundingMetadata"])
grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
if "safetyRatings" in candidate:
safety_ratings.append(candidate["safetyRatings"])
@ -973,7 +979,9 @@ class VertexLLM(VertexBase):
if "functionCall" in candidate["content"]["parts"][0]:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=candidate["content"]["parts"][0]["functionCall"]["name"],
name=candidate["content"]["parts"][0]["functionCall"][
"name"
],
arguments=json.dumps(
candidate["content"]["parts"][0]["functionCall"]["args"]
),
@ -1433,10 +1441,12 @@ class ModelResponseIterator:
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
_candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
gemini_chunk: Optional[Candidates] = None
if _candidates and len(_candidates) > 0:
gemini_chunk = _candidates[0]
gemini_chunk = processed_chunk["candidates"][0]
if "content" in gemini_chunk:
if gemini_chunk and "content" in gemini_chunk:
if "text" in gemini_chunk["content"]["parts"][0]:
text = gemini_chunk["content"]["parts"][0]["text"]
elif "functionCall" in gemini_chunk["content"]["parts"][0]:
@ -1455,7 +1465,7 @@ class ModelResponseIterator:
index=0,
)
if "finishReason" in gemini_chunk:
if gemini_chunk and "finishReason" in gemini_chunk:
finish_reason = map_finish_reason(
finish_reason=gemini_chunk["finishReason"]
)
@ -1533,6 +1543,7 @@ class ModelResponseIterator:
)
def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
try:
chunk = chunk.replace("data:", "")
if len(chunk) > 0:
"""
@ -1544,7 +1555,7 @@ class ModelResponseIterator:
return self.handle_valid_json_chunk(chunk=chunk)
elif self.chunk_type == "accumulated_json":
return self.handle_accumulated_json_chunk(chunk=chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
@ -1553,6 +1564,8 @@ class ModelResponseIterator:
index=0,
tool_use=None,
)
except Exception:
raise
def __next__(self):
try:
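
The streaming fix above amounts to treating `candidates` as optional in a Gemini chunk. A self-contained sketch of that guard on a raw chunk dict, independent of litellm's typed classes:

from typing import Any, Dict, List, Optional

def extract_chunk_text(processed_chunk: Dict[str, Any]) -> str:
    """Return the chunk's text, or '' when Gemini sends no candidates (e.g. a usage-only chunk)."""
    candidates: Optional[List[dict]] = processed_chunk.get("candidates")
    if not candidates:  # None or empty list -> nothing to parse, no KeyError
        return ""
    parts = candidates[0].get("content", {}).get("parts", [])
    return parts[0].get("text", "") if parts else ""

print(extract_chunk_text({"usageMetadata": {"totalTokenCount": 10}}))                  # ''
print(extract_chunk_text({"candidates": [{"content": {"parts": [{"text": "hi"}]}}]}))  # 'hi'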


@ -83,6 +83,7 @@ from .llms import (
from .llms.AI21 import completion as ai21
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure_ai.chat.handler import AzureAIChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
@ -166,6 +167,7 @@ openai_text_completions = OpenAITextCompletion()
openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion()
azure_ai_chat_completions = AzureAIChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
@ -1177,7 +1179,7 @@ def completion(
headers = headers or litellm.headers
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
config = litellm.AzureAIStudioConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
@ -1190,7 +1192,7 @@ def completion(
## COMPLETION CALL
try:
response = openai_chat_completions.completion(
response = azure_ai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,


@ -3862,9 +3862,9 @@
"supports_vision": true
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3906,9 +3906,9 @@
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3939,9 +3939,9 @@
"supports_vision": true
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3950,9 +3950,9 @@
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -5593,6 +5593,11 @@
"output_cost_per_token": 0.0000012,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-default": {
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-embedding-up-to-150m": {
"input_cost_per_token": 0.000000008,
"output_cost_per_token": 0.000000,


@ -31,6 +31,13 @@ model_list:
- model_name: "anthropic/*"
litellm_params:
model: "anthropic/*"
- model_name: "openai/*"
litellm_params:
model: "openai/*"
- model_name: "fireworks_ai/*"
litellm_params:
model: "fireworks_ai/*"
configurable_clientside_auth_params: ["api_base"]
litellm_settings:
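
With the `fireworks_ai/*` entry above marking `api_base` as client-configurable, a caller can override the base URL per request through the proxy. A hedged sketch using the OpenAI SDK's `extra_body` to place the field in the request body; the proxy URL, key, model name, and endpoint are placeholders:

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000",  # LiteLLM proxy (placeholder)
    api_key="sk-1234",                 # proxy virtual key (placeholder)
)

response = client.chat.completions.create(
    model="fireworks_ai/my-finetuned-model",  # matches the fireworks_ai/* wildcard
    messages=[{"role": "user", "content": "Hello"}],
    # api_base is normally rejected in the request body; it passes here because
    # the model's configurable_clientside_auth_params includes "api_base"
    extra_body={"api_base": "https://example-fireworks-endpoint/v1"},
)
print(response.choices[0].message.content)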


@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
from fastapi import HTTPException, Request, status
from litellm import Router, provider_list
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
@ -72,7 +73,41 @@ def check_complete_credentials(request_body: dict) -> bool:
return False
def is_request_body_safe(request_body: dict) -> bool:
def _allow_model_level_clientside_configurable_parameters(
model: str, param: str, llm_router: Optional[Router]
) -> bool:
"""
Check if model is allowed to use configurable client-side params
- get matching model
- check if 'clientside_configurable_parameters' is set for model
-
"""
if llm_router is None:
return False
# check if model is set
model_info = llm_router.get_model_group_info(model_group=model)
if model_info is None:
# check if wildcard model is set
if model.split("/", 1)[0] in provider_list:
model_info = llm_router.get_model_group_info(
model_group=model.split("/", 1)[0]
)
if model_info is None:
return False
if model_info is None or model_info.configurable_clientside_auth_params is None:
return False
if param in model_info.configurable_clientside_auth_params:
return True
return False
def is_request_body_safe(
request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
"""
Check if the request body is safe.
@ -88,7 +123,20 @@ def is_request_body_safe(request_body: dict) -> bool:
request_body=request_body
)
):
raise ValueError(f"BadRequest: {param} is not allowed in request body")
if general_settings.get("allow_client_side_credentials") is True:
return True
elif (
_allow_model_level_clientside_configurable_parameters(
model=model, param=param, llm_router=llm_router
)
is True
):
return True
raise ValueError(
f"Rejected Request: {param} is not allowed in request body. "
"Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
"Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
)
return True
@ -110,13 +158,20 @@ async def pre_db_read_auth_checks(
Raises:
- HTTPException if request fails initial auth checks
"""
from litellm.proxy.proxy_server import general_settings, premium_user
from litellm.proxy.proxy_server import general_settings, llm_router, premium_user
# Check 1. request size
await check_if_request_size_is_safe(request=request)
# Check 2. Request body is safe
is_request_body_safe(request_body=request_data)
is_request_body_safe(
request_body=request_data,
general_settings=general_settings,
llm_router=llm_router,
model=request_data.get(
"model", ""
), # [TODO] use model passed in url as well (azure openai routes)
)
# Check 3. Check if IP address is allowed
is_valid_ip, passed_in_ip = _check_valid_ip(


@ -66,7 +66,7 @@ async def route_request(
"""
router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data:
if "api_key" in data or "api_base" in data:
return getattr(litellm, f"{route_type}")(**data)
elif "user_config" in data:


@ -4,6 +4,8 @@ import secrets
import traceback
from typing import Optional
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
@ -105,6 +107,8 @@ def get_logging_payload(
additional_usage_values = {}
for k, v in usage.items():
if k not in special_usage_fields:
if isinstance(v, BaseModel):
v = v.model_dump()
additional_usage_values.update({k: v})
clean_metadata["additional_usage_values"] = additional_usage_values
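
The guard above just serializes any pydantic values found in the usage object before they reach the spend-log payload. A simplified stand-alone sketch (the nested usage model below is a stand-in, and the special-field filtering is omitted):

from pydantic import BaseModel

class NestedUsageDetails(BaseModel):  # stand-in for a pydantic sub-object inside usage
    cached_tokens: int = 0

usage = {"total_tokens": 42, "prompt_tokens_details": NestedUsageDetails(cached_tokens=10)}

additional_usage_values = {}
for k, v in usage.items():
    if isinstance(v, BaseModel):
        v = v.model_dump()  # make the value JSON-serializable before logging
    additional_usage_values[k] = v

print(additional_usage_values)
# {'total_tokens': 42, 'prompt_tokens_details': {'cached_tokens': 10}}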


@ -14,7 +14,17 @@ from datetime import datetime, timedelta
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from functools import wraps
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Optional,
Tuple,
Union,
get_args,
overload,
)
import backoff
import httpx
@ -222,19 +232,7 @@ class ProxyLogging:
self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold
self.alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"fallback_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
]
self.alert_types: List[AlertType] = list(get_args(AlertType))
self.alert_to_webhook_url: Optional[dict] = None
self.slack_alerting_instance: SlackAlerting = SlackAlerting(
alerting_threshold=self.alerting_threshold,


@ -4335,11 +4335,28 @@ class Router:
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
configurable_clientside_auth_params: Optional[List[str]] = None
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
is_match = False
if (
"model_name" in model and model["model_name"] == model_group
): # exact match
is_match = True
elif (
"model_name" in model
and model_group in self.provider_default_deployments
): # wildcard model
is_match = True
if not is_match:
continue
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
# get configurable clientside auth params
configurable_clientside_auth_params = (
litellm_params.configurable_clientside_auth_params
)
# get model tpm
_deployment_tpm: Optional[int] = None
if _deployment_tpm is None:
@ -4425,9 +4442,7 @@ class Router:
> model_group_info.max_input_tokens
)
):
model_group_info.max_input_tokens = model_info[
"max_input_tokens"
]
model_group_info.max_input_tokens = model_info["max_input_tokens"]
if (
model_info.get("max_output_tokens", None) is not None
and model_info["max_output_tokens"] is not None
@ -4437,9 +4452,7 @@ class Router:
> model_group_info.max_output_tokens
)
):
model_group_info.max_output_tokens = model_info[
"max_output_tokens"
]
model_group_info.max_output_tokens = model_info["max_output_tokens"]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
@ -4480,13 +4493,20 @@ class Router:
"supported_openai_params"
]
if model_group_info is not None:
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
if total_tpm is not None and model_group_info is not None:
if total_tpm is not None:
model_group_info.tpm = total_tpm
if total_rpm is not None and model_group_info is not None:
if total_rpm is not None:
model_group_info.rpm = total_rpm
## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP
if configurable_clientside_auth_params is not None:
model_group_info.configurable_clientside_auth_params = (
configurable_clientside_auth_params
)
return model_group_info
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
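
A hedged sketch of the router change: a wildcard deployment can now satisfy `get_model_group_info()`, and its `configurable_clientside_auth_params` are surfaced on the returned group info. The key and model names are placeholders, and the wildcard lookup assumes the provider-default deployment path shown in the diff:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fireworks_ai/*",
            "litellm_params": {
                "model": "fireworks_ai/*",
                "api_key": "fw-...",  # placeholder
                "configurable_clientside_auth_params": ["api_base"],
            },
        }
    ]
)

# The auth check resolves wildcard models via the provider prefix
info = router.get_model_group_info(model_group="fireworks_ai")
if info is not None:
    print(info.configurable_clientside_auth_params)  # ['api_base']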


@ -141,9 +141,16 @@ def test_completion_azure_ai_command_r():
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
response: litellm.ModelResponse = completion(
response = completion(
model="azure_ai/command-r-plus",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the meaning of life?"}
],
}
],
) # type: ignore
assert "azure_ai" in response.model


@ -1257,14 +1257,31 @@ def test_completion_cost_databricks_embedding(model):
cost = completion_cost(completion_response=resp)
def test_completion_cost_fireworks_ai():
from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing
@pytest.mark.parametrize(
"model, base_model",
[
("fireworks_ai/llama-v3p1-405b-instruct", "fireworks-ai-default"),
("fireworks_ai/mixtral-8x7b-instruct", "fireworks-ai-moe-up-to-56b"),
],
)
def test_get_model_params_fireworks_ai(model, base_model):
pricing_model = get_base_model_for_pricing(model_name=model)
assert base_model == pricing_model
@pytest.mark.parametrize(
"model",
["fireworks_ai/llama-v3p1-405b-instruct", "fireworks_ai/mixtral-8x7b-instruct"],
)
def test_completion_cost_fireworks_ai(model):
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
messages = [{"role": "user", "content": "Hey, how's it going?"}]
resp = litellm.completion(
model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
) # works fine
resp = litellm.completion(model=model, messages=messages) # works fine
print(resp)
cost = completion_cost(completion_response=resp)


@ -12,6 +12,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import is_request_body_safe
from litellm.proxy.litellm_pre_call_utils import (
_get_dynamic_logging_metadata,
add_litellm_data_to_request,
@ -291,3 +292,78 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
for var in callbacks.callback_vars.values():
assert "os.environ" not in var
@pytest.mark.parametrize(
"allow_client_side_credentials, expect_error", [(True, False), (False, True)]
)
def test_is_request_body_safe_global_enabled(
allow_client_side_credentials, expect_error
):
from litellm import Router
error_raised = False
llm_router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
}
]
)
try:
is_request_body_safe(
request_body={"api_base": "hello-world"},
general_settings={
"allow_client_side_credentials": allow_client_side_credentials
},
llm_router=llm_router,
model="gpt-3.5-turbo",
)
except Exception as e:
print(e)
error_raised = True
assert expect_error == error_raised
@pytest.mark.parametrize(
"allow_client_side_credentials, expect_error", [(True, False), (False, True)]
)
def test_is_request_body_safe_model_enabled(
allow_client_side_credentials, expect_error
):
from litellm import Router
error_raised = False
llm_router = Router(
model_list=[
{
"model_name": "fireworks_ai/*",
"litellm_params": {
"model": "fireworks_ai/*",
"api_key": os.getenv("FIREWORKS_API_KEY"),
"configurable_clientside_auth_params": (
["api_base"] if allow_client_side_credentials else []
),
},
}
]
)
try:
is_request_body_safe(
request_body={"api_base": "hello-world"},
general_settings={},
llm_router=llm_router,
model="fireworks_ai/my-new-model",
)
except Exception as e:
print(e)
error_raised = True
assert expect_error == error_raised


@ -283,7 +283,7 @@ class PromptFeedback(TypedDict):
class GenerateContentResponseBody(TypedDict, total=False):
candidates: Required[List[Candidates]]
candidates: List[Candidates]
promptFeedback: PromptFeedback
usageMetadata: Required[UsageMetadata]


@ -139,6 +139,7 @@ class GenericLiteLLMParams(BaseModel):
)
max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs
configurable_clientside_auth_params: Optional[List[str]] = None
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None
## VERTEX AI ##
@ -310,6 +311,9 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
stream_timeout: Optional[Union[float, str]]
max_retries: Optional[int]
organization: Optional[Union[List, str]] # for openai orgs
configurable_clientside_auth_params: Optional[
List[str]
] # for allowing api base switching on finetuned models
## DROP PARAMS ##
drop_params: Optional[bool]
## UNIFIED PROJECT/REGION ##
@ -487,6 +491,7 @@ class ModelGroupInfo(BaseModel):
supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False)
supported_openai_params: Optional[List[str]] = Field(default=[])
configurable_clientside_auth_params: Optional[List[str]] = None
class AssistantsTypedDict(TypedDict):


@ -1196,6 +1196,7 @@ all_litellm_params = [
"client_id",
"client_secret",
"user_continue_message",
"configurable_clientside_auth_params",
]
@ -1323,7 +1324,7 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata
cache_hit: Optional[bool]
cache_key: Optional[str]
saved_cache_cost: Optional[float]
saved_cache_cost: float
request_tags: list
end_user: Optional[str]
requester_ip_address: Optional[str]


@ -3862,9 +3862,9 @@
"supports_vision": true
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3906,9 +3906,9 @@
"supports_vision": true
},
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3939,9 +3939,9 @@
"supports_vision": true
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -3950,9 +3950,9 @@
"supports_vision": true
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"max_tokens": 8192,
"max_tokens": 4096,
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
@ -5593,6 +5593,11 @@
"output_cost_per_token": 0.0000012,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-default": {
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "fireworks_ai"
},
"fireworks-ai-embedding-up-to-150m": {
"input_cost_per_token": 0.000000008,
"output_cost_per_token": 0.000000,