forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842)

* feat(auth_utils.py): enable admin to allow client-side credentials to be passed. Makes it easier for devs to experiment with finetuned Fireworks AI models.
* feat(router.py): allow setting configurable_clientside_auth_params for a model. Closes https://github.com/BerriAI/litellm/issues/5843
* build(model_prices_and_context_window.json): fix anthropic claude-3-5-sonnet max output token limit. Fixes https://github.com/BerriAI/litellm/issues/5850
* fix(azure_ai/): support content list for azure ai. Fixes https://github.com/BerriAI/litellm/issues/4237
* fix(litellm_logging.py): always set saved_cache_cost, defaulting to 0
* fix(fireworks_ai/cost_calculator.py): add fireworks ai default pricing; handles calling 405b+ size models
* fix(slack_alerting.py): fix error alerting for failed spend tracking. Fixes regression with slack alerting error monitoring.
* fix(vertex_and_google_ai_studio_gemini.py): handle gemini "no candidates in streaming chunk" error
* docs(bedrock.md): add llama3-1 models
* test: fix tests
* fix(azure_ai/chat): fix transformation for azure ai calls
This commit is contained in: parent 4df9aca45e, commit d37c8b5c6b.
25 changed files with 611 additions and 294 deletions.
@@ -987,6 +987,9 @@ Here's an example of using a bedrock model with LiteLLM. For a complete list, re
 | Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
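For reference, a minimal sketch of calling one of the newly documented llama3-1 Bedrock models through LiteLLM; the model name and environment variables come from the table above, the region value is only an example.

```python
import os
import litellm

# AWS credentials, as listed in the table above
os.environ["AWS_ACCESS_KEY_ID"] = "..."
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."
os.environ["AWS_REGION_NAME"] = "us-west-2"  # example region

response = litellm.completion(
    model="bedrock/meta.llama3-1-405b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```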
@@ -963,8 +963,8 @@ from .llms.OpenAI.openai import (
     MistralEmbeddingConfig,
     DeepInfraConfig,
     GroqConfig,
-    AzureAIStudioConfig,
 )
+from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.OpenAI.chat.o1_transformation import (
     OpenAIO1Config,
@@ -10,7 +10,7 @@ import traceback
 from datetime import datetime as dt
 from datetime import timedelta, timezone
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union, get_args

 import aiohttp
 import dotenv
@@ -57,20 +57,7 @@ class SlackAlerting(CustomBatchLogger):
             float
         ] = None,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
-        alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-            "failed_tracking_spend",
-        ],
+        alert_types: List[AlertType] = list(get_args(AlertType)),
         alert_to_webhook_url: Optional[
             Dict[AlertType, Union[List[str], str]]
         ] = None,  # if user wants to separate alerts to diff channels
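As an aside, a minimal sketch of the `list(get_args(...))` pattern used above; the `AlertType` values here are illustrative stand-ins, the real Literal lives in litellm's types.

```python
from typing import List, Literal, get_args

# Illustrative stand-in for litellm's AlertType Literal
AlertType = Literal["llm_exceptions", "llm_too_slow", "budget_alerts"]

# get_args() returns the Literal's members, so the default alert list
# stays in sync with the type definition instead of being hand-maintained.
default_alert_types: List[str] = list(get_args(AlertType))
print(default_alert_types)  # ['llm_exceptions', 'llm_too_slow', 'budget_alerts']
```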
@@ -613,7 +600,7 @@ class SlackAlerting(CustomBatchLogger):
             await self.send_alert(
                 message=message,
                 level="High",
-                alert_type="budget_alerts",
+                alert_type="failed_tracking_spend",
                 alerting_metadata={},
             )
             await _cache.async_set_cache(
@@ -2498,14 +2498,17 @@ def get_standard_logging_object_payload(
     else:
         cache_key = None

-    saved_cache_cost: Optional[float] = None
+    saved_cache_cost: float = 0.0
     if cache_hit is True:

         id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id

-        saved_cache_cost = logging_obj._response_cost_calculator(
-            result=init_response_obj, cache_hit=False  # type: ignore
-        )
+        saved_cache_cost = (
+            logging_obj._response_cost_calculator(
+                result=init_response_obj, cache_hit=False  # type: ignore
+            )
+            or 0.0
+        )

     ## Get model cost information ##
     base_model = _get_base_model_from_metadata(model_call_details=kwargs)
@@ -103,25 +103,6 @@ class MistralEmbeddingConfig:
         return optional_params


-class AzureAIStudioConfig:
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return it's required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Azure AI Studio API Key.",
-                field_value="zEJ...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Azure AI Studio API Base.",
-                field_value="https://Mistral-serverless.",
-            ),
-        ]
-
-
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
litellm/llms/azure_ai/chat/handler.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Any, Callable, List, Optional, Union

from httpx._config import Timeout

from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from .transformation import AzureAIStudioConfig


class AzureAIChatCompletion(OpenAIChatCompletion):
    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, Timeout],
        optional_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable[..., Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):

        transformed_messages = AzureAIStudioConfig()._transform_messages(
            messages=messages  # type: ignore
        )

        return super().completion(
            model_response,
            timeout,
            optional_params,
            logging_obj,
            model,
            transformed_messages,
            print_verbose,
            api_key,
            api_base,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            custom_prompt_dict,
            client,
            organization,
            custom_llm_provider,
            drop_params,
        )
litellm/llms/azure_ai/chat/transformation.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from typing import List

from litellm.llms.OpenAI.openai import OpenAIConfig
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField


class AzureAIStudioConfig(OpenAIConfig):
    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return it's required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(self, messages: List[AllMessageValues]) -> List:
        for message in messages:
            message = convert_content_list_to_str(message=message)

        return messages
@@ -10,7 +10,7 @@ from litellm.utils import get_model_info

 # Extract the number of billion parameters from the model name
 # only used for together_computer LLMs
-def get_model_params_and_category(model_name: str) -> str:
+def get_base_model_for_pricing(model_name: str) -> str:
     """
     Helper function for calculating together ai pricing.

@@ -43,7 +43,7 @@ def get_model_params_and_category(model_name: str) -> str:
         return "fireworks-ai-16b-80b"

     # If no matches, return the original model_name
-    return model_name
+    return "fireworks-ai-default"


 def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:

@@ -57,10 +57,16 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
     Returns:
         Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
     """
-    base_model = get_model_params_and_category(model_name=model)
+    ## check if model mapped, else use default pricing
+    try:
+        model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
+    except Exception:
+        base_model = get_base_model_for_pricing(model_name=model)

-    ## GET MODEL INFO
-    model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
+        ## GET MODEL INFO
+        model_info = get_model_info(
+            model=base_model, custom_llm_provider="fireworks_ai"
+        )

     ## CALCULATE INPUT COST
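For illustration, a hedged sketch of how this fallback surfaces to users: costing a Fireworks AI model that has no dedicated entry in the cost map now resolves to the `fireworks-ai-default` pricing (0.0 per token in the JSON below) instead of raising. The calls and model name mirror the test added later in this commit.

```python
import litellm
from litellm import completion_cost

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# 405b+ models are not in the per-size pricing buckets, so the lookup falls back
# to "fireworks-ai-default" rather than erroring out.
resp = litellm.completion(
    model="fireworks_ai/llama-v3p1-405b-instruct", messages=messages
)
print(completion_cost(completion_response=resp))
```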
litellm/llms/prompt_templates/common_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""
Common utility functions used for translating messages across providers
"""

from typing import List

from litellm.types.llms.openai import AllMessageValues


def convert_content_list_to_str(message: AllMessageValues) -> AllMessageValues:
    """
    - handles scenario where content is list and not string
    - content list is just text, and no images
    - if image passed in, then just return as is (user-intended)

    Motivation: mistral api + azure ai don't support content as a list
    """
    texts = ""
    message_content = message.get("content")
    if message_content:
        if message_content is not None and isinstance(message_content, list):
            for c in message_content:
                text_content = c.get("text")
                if text_content:
                    texts += text_content
        elif message_content is not None and isinstance(message_content, str):
            texts = message_content

    if texts:
        message["content"] = texts

    return message
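To make the transformation concrete, a small sketch of what `convert_content_list_to_str` does to an OpenAI-style message whose `content` is a list of text parts (illustrative input only):

```python
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str

# An OpenAI-style message whose content is a list of text parts
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is the meaning "},
        {"type": "text", "text": "of life?"},
    ],
}

# The text parts get concatenated into a plain string, which the Mistral and
# Azure AI endpoints accept.
print(convert_content_list_to_str(message=message))
# {'role': 'user', 'content': 'What is the meaning of life?'}
```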
@@ -49,6 +49,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
+    Candidates,
     ContentType,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -187,7 +188,11 @@ class VertexAIConfig:
                 optional_params["stop_sequences"] = value
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":
+            if (
+                param == "response_format"
+                and isinstance(value, dict)
+                and value["type"] == "json_object"
+            ):
                 optional_params["response_mime_type"] = "application/json"
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
@@ -900,14 +905,14 @@ class VertexLLM(VertexBase):

             return model_response

-        if len(completion_response["candidates"]) > 0:
+        _candidates = completion_response.get("candidates")
+        if _candidates and len(_candidates) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
             )
             if (
-                "finishReason" in completion_response["candidates"][0]
-                and completion_response["candidates"][0]["finishReason"]
-                in content_policy_violations.keys()
+                "finishReason" in _candidates[0]
+                and _candidates[0]["finishReason"] in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
                 model_response.choices[0].finish_reason = "content_filter"
@@ -956,12 +961,13 @@ class VertexLLM(VertexBase):
         content_str = ""
         tools: List[ChatCompletionToolCallChunk] = []
         functions: Optional[ChatCompletionToolCallFunctionChunk] = None
-        for idx, candidate in enumerate(completion_response["candidates"]):
+        if _candidates:
+            for idx, candidate in enumerate(_candidates):
                 if "content" not in candidate:
                     continue

                 if "groundingMetadata" in candidate:
-                    grounding_metadata.append(candidate["groundingMetadata"])
+                    grounding_metadata.append(candidate["groundingMetadata"])  # type: ignore

                 if "safetyRatings" in candidate:
                     safety_ratings.append(candidate["safetyRatings"])
@@ -973,7 +979,9 @@ class VertexLLM(VertexBase):

                 if "functionCall" in candidate["content"]["parts"][0]:
                     _function_chunk = ChatCompletionToolCallFunctionChunk(
-                        name=candidate["content"]["parts"][0]["functionCall"]["name"],
+                        name=candidate["content"]["parts"][0]["functionCall"][
+                            "name"
+                        ],
                         arguments=json.dumps(
                             candidate["content"]["parts"][0]["functionCall"]["args"]
                         ),
@@ -1433,10 +1441,12 @@ class ModelResponseIterator:
             is_finished = False
             finish_reason = ""
             usage: Optional[ChatCompletionUsageBlock] = None
+            _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
+            gemini_chunk: Optional[Candidates] = None
+            if _candidates and len(_candidates) > 0:
+                gemini_chunk = _candidates[0]

-            gemini_chunk = processed_chunk["candidates"][0]
-
-            if "content" in gemini_chunk:
+            if gemini_chunk and "content" in gemini_chunk:
                 if "text" in gemini_chunk["content"]["parts"][0]:
                     text = gemini_chunk["content"]["parts"][0]["text"]
                 elif "functionCall" in gemini_chunk["content"]["parts"][0]:
@@ -1455,7 +1465,7 @@ class ModelResponseIterator:
                     index=0,
                 )

-            if "finishReason" in gemini_chunk:
+            if gemini_chunk and "finishReason" in gemini_chunk:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
@@ -1533,6 +1543,7 @@ class ModelResponseIterator:
            )

    def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
        try:
            chunk = chunk.replace("data:", "")
            if len(chunk) > 0:
                """
@@ -1544,7 +1555,7 @@ class ModelResponseIterator:
                return self.handle_valid_json_chunk(chunk=chunk)
            elif self.chunk_type == "accumulated_json":
                return self.handle_accumulated_json_chunk(chunk=chunk)
            else:

                return GenericStreamingChunk(
                    text="",
                    is_finished=False,
@@ -1553,6 +1564,8 @@ class ModelResponseIterator:
                    index=0,
                    tool_use=None,
                )
        except Exception:
            raise

    def __next__(self):
        try:
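To summarize the defensive pattern introduced above, a hedged sketch (plain dicts standing in for litellm's typed chunk objects) of reading the first candidate only when Gemini actually returns one:

```python
from typing import Optional


def first_candidate(processed_chunk: dict) -> Optional[dict]:
    # Gemini streaming chunks may arrive with no "candidates" key at all,
    # so use .get() and a truthiness check instead of indexing directly.
    _candidates = processed_chunk.get("candidates")
    if _candidates and len(_candidates) > 0:
        return _candidates[0]
    return None


# A usage-metadata-only chunk (no candidates) no longer raises a KeyError:
print(first_candidate({"usageMetadata": {"totalTokenCount": 10}}))  # None
print(first_candidate({"candidates": [{"content": {"parts": [{"text": "hi"}]}}]}))
```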
@@ -83,6 +83,7 @@ from .llms import (
 from .llms.AI21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.anthropic.completion import AnthropicTextCompletion
+from .llms.azure_ai.chat.handler import AzureAIChatCompletion
 from .llms.azure_text import AzureTextCompletion
 from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
 from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params

@@ -166,6 +167,7 @@ openai_text_completions = OpenAITextCompletion()
 openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
+azure_ai_chat_completions = AzureAIChatCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()

@@ -1177,7 +1179,7 @@ def completion(
         headers = headers or litellm.headers

         ## LOAD CONFIG - if set
-        config = litellm.OpenAIConfig.get_config()
+        config = litellm.AzureAIStudioConfig.get_config()
         for k, v in config.items():
             if (
                 k not in optional_params

@@ -1190,7 +1192,7 @@ def completion(

         ## COMPLETION CALL
         try:
-            response = openai_chat_completions.completion(
+            response = azure_ai_chat_completions.completion(
                 model=model,
                 messages=messages,
                 headers=headers,
@@ -3862,9 +3862,9 @@
         "supports_vision": true
     },
     "anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3906,9 +3906,9 @@
         "supports_vision": true
     },
     "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3939,9 +3939,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-sonnet-20240229-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3950,9 +3950,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -5593,6 +5593,11 @@
         "output_cost_per_token": 0.0000012,
         "litellm_provider": "fireworks_ai"
     },
+    "fireworks-ai-default": {
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "fireworks_ai"
+    },
     "fireworks-ai-embedding-up-to-150m": {
         "input_cost_per_token": 0.000000008,
         "output_cost_per_token": 0.000000,
@@ -31,6 +31,13 @@ model_list:
   - model_name: "anthropic/*"
     litellm_params:
       model: "anthropic/*"
+  - model_name: "openai/*"
+    litellm_params:
+      model: "openai/*"
+  - model_name: "fireworks_ai/*"
+    litellm_params:
+      model: "fireworks_ai/*"
+      configurable_clientside_auth_params: ["api_base"]


 litellm_settings:
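For context, a hedged sketch of what this config enables: a client may now pass `api_base` in the request body for the `fireworks_ai/*` model group (it is listed in `configurable_clientside_auth_params`) without turning on the global `allow_client_side_credentials` setting. The proxy URL, virtual key, model name, and endpoint below are placeholders.

```python
import requests

# Placeholder proxy URL and virtual key
PROXY_URL = "http://0.0.0.0:4000/chat/completions"
headers = {"Authorization": "Bearer sk-1234"}

payload = {
    "model": "fireworks_ai/my-finetuned-model",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    # Allowed because the fireworks_ai/* model group lists "api_base" in
    # configurable_clientside_auth_params; otherwise the proxy rejects the request.
    "api_base": "https://my-finetuned-endpoint.example.com",
}

print(requests.post(PROXY_URL, headers=headers, json=payload).json())
```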
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple

 from fastapi import HTTPException, Request, status

+from litellm import Router, provider_list
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *

@@ -72,7 +73,41 @@ def check_complete_credentials(request_body: dict) -> bool:
     return False


-def is_request_body_safe(request_body: dict) -> bool:
+def _allow_model_level_clientside_configurable_parameters(
+    model: str, param: str, llm_router: Optional[Router]
+) -> bool:
+    """
+    Check if model is allowed to use configurable client-side params
+    - get matching model
+    - check if 'clientside_configurable_parameters' is set for model
+    -
+    """
+    if llm_router is None:
+        return False
+    # check if model is set
+    model_info = llm_router.get_model_group_info(model_group=model)
+    if model_info is None:
+        # check if wildcard model is set
+        if model.split("/", 1)[0] in provider_list:
+            model_info = llm_router.get_model_group_info(
+                model_group=model.split("/", 1)[0]
+            )
+
+    if model_info is None:
+        return False
+
+    if model_info is None or model_info.configurable_clientside_auth_params is None:
+        return False
+
+    if param in model_info.configurable_clientside_auth_params:
+        return True
+
+    return False
+
+
+def is_request_body_safe(
+    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
     """
     Check if the request body is safe.

@@ -88,7 +123,20 @@ def is_request_body_safe(request_body: dict) -> bool:
                 request_body=request_body
             )
         ):
-            raise ValueError(f"BadRequest: {param} is not allowed in request body")
+            if general_settings.get("allow_client_side_credentials") is True:
+                return True
+            elif (
+                _allow_model_level_clientside_configurable_parameters(
+                    model=model, param=param, llm_router=llm_router
+                )
+                is True
+            ):
+                return True
+            raise ValueError(
+                f"Rejected Request: {param} is not allowed in request body. "
+                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+            )

     return True

@@ -110,13 +158,20 @@ async def pre_db_read_auth_checks(
     Raises:
     - HTTPException if request fails initial auth checks
     """
-    from litellm.proxy.proxy_server import general_settings, premium_user
+    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

     # Check 1. request size
     await check_if_request_size_is_safe(request=request)

     # Check 2. Request body is safe
-    is_request_body_safe(request_body=request_data)
+    is_request_body_safe(
+        request_body=request_data,
+        general_settings=general_settings,
+        llm_router=llm_router,
+        model=request_data.get(
+            "model", ""
+        ),  # [TODO] use model passed in url as well (azure openai routes)
+    )

     # Check 3. Check if IP address is allowed
     is_valid_ip, passed_in_ip = _check_valid_ip(
@@ -66,7 +66,7 @@ async def route_request(
     """
     router_model_names = llm_router.model_names if llm_router is not None else []

-    if "api_key" in data:
+    if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)

     elif "user_config" in data:
@@ -4,6 +4,8 @@ import secrets
 import traceback
 from typing import Optional

+from pydantic import BaseModel
+
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload

@@ -105,6 +107,8 @@ def get_logging_payload(
     additional_usage_values = {}
     for k, v in usage.items():
         if k not in special_usage_fields:
+            if isinstance(v, BaseModel):
+                v = v.model_dump()
             additional_usage_values.update({k: v})
     clean_metadata["additional_usage_values"] = additional_usage_values
@@ -14,7 +14,17 @@ from datetime import datetime, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from functools import wraps
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    overload,
+)

 import backoff
 import httpx

@@ -222,19 +232,7 @@ class ProxyLogging:
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-        ]
+        self.alert_types: List[AlertType] = list(get_args(AlertType))
         self.alert_to_webhook_url: Optional[dict] = None
         self.slack_alerting_instance: SlackAlerting = SlackAlerting(
             alerting_threshold=self.alerting_threshold,
@@ -4335,11 +4335,28 @@ class Router:

         total_tpm: Optional[int] = None
         total_rpm: Optional[int] = None
+        configurable_clientside_auth_params: Optional[List[str]] = None

         for model in self.model_list:
-            if "model_name" in model and model["model_name"] == model_group:
+            is_match = False
+            if (
+                "model_name" in model and model["model_name"] == model_group
+            ):  # exact match
+                is_match = True
+            elif (
+                "model_name" in model
+                and model_group in self.provider_default_deployments
+            ):  # wildcard model
+                is_match = True
+
+            if not is_match:
+                continue
+            # model in model group found #
             litellm_params = LiteLLM_Params(**model["litellm_params"])
+            # get configurable clientside auth params
+            configurable_clientside_auth_params = (
+                litellm_params.configurable_clientside_auth_params
+            )
             # get model tpm
             _deployment_tpm: Optional[int] = None
             if _deployment_tpm is None:

@@ -4425,9 +4442,7 @@ class Router:
                     > model_group_info.max_input_tokens
                 )
             ):
-                model_group_info.max_input_tokens = model_info[
-                    "max_input_tokens"
-                ]
+                model_group_info.max_input_tokens = model_info["max_input_tokens"]
             if (
                 model_info.get("max_output_tokens", None) is not None
                 and model_info["max_output_tokens"] is not None

@@ -4437,9 +4452,7 @@ class Router:
                     > model_group_info.max_output_tokens
                 )
             ):
-                model_group_info.max_output_tokens = model_info[
-                    "max_output_tokens"
-                ]
+                model_group_info.max_output_tokens = model_info["max_output_tokens"]
             if model_info.get("input_cost_per_token", None) is not None and (
                 model_group_info.input_cost_per_token is None
                 or model_info["input_cost_per_token"]

@@ -4480,13 +4493,20 @@ class Router:
                     "supported_openai_params"
                 ]

+        if model_group_info is not None:
             ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
-        if total_tpm is not None and model_group_info is not None:
-            model_group_info.tpm = total_tpm
-
-        if total_rpm is not None and model_group_info is not None:
-            model_group_info.rpm = total_rpm
+            if total_tpm is not None:
+                model_group_info.tpm = total_tpm
+
+            if total_rpm is not None:
+                model_group_info.rpm = total_rpm
+
+            ## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP
+            if configurable_clientside_auth_params is not None:
+                model_group_info.configurable_clientside_auth_params = (
+                    configurable_clientside_auth_params
+                )

         return model_group_info

     def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
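A hedged sketch of the router-level behavior this enables, using the same model_list shape as the tests below; the model names and API key environment variable are illustrative.

```python
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fireworks_ai/*",  # wildcard model group
            "litellm_params": {
                "model": "fireworks_ai/*",
                "api_key": os.getenv("FIREWORKS_API_KEY"),
                # allow callers to override api_base for this group only
                "configurable_clientside_auth_params": ["api_base"],
            },
        }
    ]
)

# get_model_group_info() now surfaces the allowed client-side params
info = router.get_model_group_info(model_group="fireworks_ai/*")
if info is not None:
    print(info.configurable_clientside_auth_params)  # ['api_base']
```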
@@ -141,9 +141,16 @@ def test_completion_azure_ai_command_r():
         os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
         os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")

-        response: litellm.ModelResponse = completion(
+        response = completion(
             model="azure_ai/command-r-plus",
-            messages=[{"role": "user", "content": "What is the meaning of life?"}],
-        )
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the meaning of life?"}
+                    ],
+                }
+            ],
+        )  # type: ignore

         assert "azure_ai" in response.model
@@ -1257,14 +1257,31 @@ def test_completion_cost_databricks_embedding(model):
     cost = completion_cost(completion_response=resp)


-def test_completion_cost_fireworks_ai():
+from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing
+
+
+@pytest.mark.parametrize(
+    "model, base_model",
+    [
+        ("fireworks_ai/llama-v3p1-405b-instruct", "fireworks-ai-default"),
+        ("fireworks_ai/mixtral-8x7b-instruct", "fireworks-ai-moe-up-to-56b"),
+    ],
+)
+def test_get_model_params_fireworks_ai(model, base_model):
+    pricing_model = get_base_model_for_pricing(model_name=model)
+    assert base_model == pricing_model
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["fireworks_ai/llama-v3p1-405b-instruct", "fireworks_ai/mixtral-8x7b-instruct"],
+)
+def test_completion_cost_fireworks_ai(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")

     messages = [{"role": "user", "content": "Hey, how's it going?"}]
-    resp = litellm.completion(
-        model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
-    )  # works fine
+    resp = litellm.completion(model=model, messages=messages)  # works fine

     print(resp)
     cost = completion_cost(completion_response=resp)
@@ -12,6 +12,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import is_request_body_safe
 from litellm.proxy.litellm_pre_call_utils import (
     _get_dynamic_logging_metadata,
     add_litellm_data_to_request,

@@ -291,3 +292,78 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):

     for var in callbacks.callback_vars.values():
         assert "os.environ" not in var
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_global_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={
+                "allow_client_side_credentials": allow_client_side_credentials
+            },
+            llm_router=llm_router,
+            model="gpt-3.5-turbo",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_model_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "fireworks_ai/*",
+                "litellm_params": {
+                    "model": "fireworks_ai/*",
+                    "api_key": os.getenv("FIREWORKS_API_KEY"),
+                    "configurable_clientside_auth_params": (
+                        ["api_base"] if allow_client_side_credentials else []
+                    ),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={},
+            llm_router=llm_router,
+            model="fireworks_ai/my-new-model",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised
@@ -283,7 +283,7 @@ class PromptFeedback(TypedDict):


 class GenerateContentResponseBody(TypedDict, total=False):
-    candidates: Required[List[Candidates]]
+    candidates: List[Candidates]
     promptFeedback: PromptFeedback
     usageMetadata: Required[UsageMetadata]
@@ -139,6 +139,7 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
+    configurable_clientside_auth_params: Optional[List[str]] = None
     ## UNIFIED PROJECT/REGION ##
     region_name: Optional[str] = None
     ## VERTEX AI ##

@@ -310,6 +311,9 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
     organization: Optional[Union[List, str]]  # for openai orgs
+    configurable_clientside_auth_params: Optional[
+        List[str]
+    ]  # for allowing api base switching on finetuned models
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##

@@ -487,6 +491,7 @@ class ModelGroupInfo(BaseModel):
     supports_vision: bool = Field(default=False)
     supports_function_calling: bool = Field(default=False)
     supported_openai_params: Optional[List[str]] = Field(default=[])
+    configurable_clientside_auth_params: Optional[List[str]] = None


 class AssistantsTypedDict(TypedDict):
@@ -1196,6 +1196,7 @@ all_litellm_params = [
     "client_id",
     "client_secret",
     "user_continue_message",
+    "configurable_clientside_auth_params",
 ]

@@ -1323,7 +1324,7 @@ class StandardLoggingPayload(TypedDict):
     metadata: StandardLoggingMetadata
     cache_hit: Optional[bool]
     cache_key: Optional[str]
-    saved_cache_cost: Optional[float]
+    saved_cache_cost: float
     request_tags: list
     end_user: Optional[str]
     requester_ip_address: Optional[str]
@@ -3862,9 +3862,9 @@
         "supports_vision": true
     },
     "anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3906,9 +3906,9 @@
         "supports_vision": true
     },
     "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3939,9 +3939,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-sonnet-20240229-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -3950,9 +3950,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",

@@ -5593,6 +5593,11 @@
         "output_cost_per_token": 0.0000012,
         "litellm_provider": "fireworks_ai"
     },
+    "fireworks-ai-default": {
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "fireworks_ai"
+    },
     "fireworks-ai-embedding-up-to-150m": {
         "input_cost_per_token": 0.000000008,
         "output_cost_per_token": 0.000000,