LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842)

* feat(auth_utils.py): enable admin to allow client-side credentials to be passed; makes it easier for devs to experiment with finetuned Fireworks AI models
* feat(router.py): allow setting configurable_clientside_auth_params for a model. Closes https://github.com/BerriAI/litellm/issues/5843
* build(model_prices_and_context_window.json): fix Anthropic claude-3-5-sonnet max output token limit. Fixes https://github.com/BerriAI/litellm/issues/5850
* fix(azure_ai/): support content list for Azure AI. Fixes https://github.com/BerriAI/litellm/issues/4237
* fix(litellm_logging.py): always set saved_cache_cost (set to 0 by default)
* fix(fireworks_ai/cost_calculator.py): add Fireworks AI default pricing; handles calling 405b+ size models
* fix(slack_alerting.py): fix error alerting for failed spend tracking; fixes regression with Slack alerting error monitoring
* fix(vertex_and_google_ai_studio_gemini.py): handle Gemini "no candidates in streaming chunk" error
* docs(bedrock.md): add llama3-1 models
* test: fix tests
* fix(azure_ai/chat): fix transformation for Azure AI calls
parent 4df9aca45e
commit d37c8b5c6b

25 changed files with 611 additions and 294 deletions
@@ -987,6 +987,9 @@ Here's an example of using a bedrock model with LiteLLM. For a complete list, re
 | Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-405b | `completion(model='bedrock/meta.llama3-1-405b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-70b | `completion(model='bedrock/meta.llama3-1-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
+| Meta llama3-1-8b | `completion(model='bedrock/meta.llama3-1-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-70b | `completion(model='bedrock/meta.llama3-70b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Meta llama3-8b | `completion(model='bedrock/meta.llama3-8b-instruct-v1:0', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']` |
 | Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
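For reference, a minimal sketch (not part of the diff) of calling one of the newly documented llama3-1 Bedrock models with `litellm.completion`; the credential and region values are placeholders:

```python
# Sketch: calling a llama3-1 Bedrock model via litellm, using the env vars
# listed in the table above. Credential values here are placeholders.
import os
from litellm import completion

os.environ["AWS_ACCESS_KEY_ID"] = "my-aws-access-key-id"    # placeholder
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-aws-secret-key"   # placeholder
os.environ["AWS_REGION_NAME"] = "us-west-2"                 # placeholder

response = completion(
    model="bedrock/meta.llama3-1-405b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```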
@@ -963,8 +963,8 @@ from .llms.OpenAI.openai import (
     MistralEmbeddingConfig,
     DeepInfraConfig,
     GroqConfig,
-    AzureAIStudioConfig,
 )
+from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.OpenAI.chat.o1_transformation import (
     OpenAIO1Config,
@@ -10,7 +10,7 @@ import traceback
 from datetime import datetime as dt
 from datetime import timedelta, timezone
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Set, TypedDict, Union, get_args

 import aiohttp
 import dotenv
@@ -57,20 +57,7 @@ class SlackAlerting(CustomBatchLogger):
             float
         ] = None,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
-        alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-            "failed_tracking_spend",
-        ],
+        alert_types: List[AlertType] = list(get_args(AlertType)),
         alert_to_webhook_url: Optional[
            Dict[AlertType, Union[List[str], str]]
        ] = None,  # if user wants to separate alerts to diff channels
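For context, a small hedged sketch of why `list(get_args(AlertType))` reproduces the removed hardcoded default: `typing.get_args` on a `Literal` alias returns every literal member (only a subset of the real `AlertType` values is shown):

```python
# Sketch: get_args expands a Literal alias into the full tuple of its values,
# so list(get_args(AlertType)) stays in sync with the type definition.
from typing import Literal, get_args

AlertType = Literal[  # illustrative subset of the values in the removed list
    "llm_exceptions",
    "llm_too_slow",
    "budget_alerts",
    "failed_tracking_spend",
]

assert list(get_args(AlertType)) == [
    "llm_exceptions",
    "llm_too_slow",
    "budget_alerts",
    "failed_tracking_spend",
]
```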
@@ -613,7 +600,7 @@ class SlackAlerting(CustomBatchLogger):
                 await self.send_alert(
                     message=message,
                     level="High",
-                    alert_type="budget_alerts",
+                    alert_type="failed_tracking_spend",
                     alerting_metadata={},
                 )
                 await _cache.async_set_cache(
@@ -2498,14 +2498,17 @@ def get_standard_logging_object_payload(
     else:
         cache_key = None

-    saved_cache_cost: Optional[float] = None
+    saved_cache_cost: float = 0.0
     if cache_hit is True:

         id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id

-        saved_cache_cost = logging_obj._response_cost_calculator(
-            result=init_response_obj, cache_hit=False  # type: ignore
-        )
+        saved_cache_cost = (
+            logging_obj._response_cost_calculator(
+                result=init_response_obj, cache_hit=False  # type: ignore
+            )
+            or 0.0
+        )

     ## Get model cost information ##
     base_model = _get_base_model_from_metadata(model_call_details=kwargs)
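A small sketch (with a stand-in function name) of the coercion pattern used above: the cost calculator may return `None`, and `or 0.0` keeps `saved_cache_cost` a concrete float so the logging payload always carries a number:

```python
# Sketch: if the calculator returns None (e.g. unknown model), `or 0.0`
# coerces the result so the field is never Optional in the payload.
from typing import Optional

def response_cost_calculator() -> Optional[float]:  # stand-in for the real method
    return None  # no pricing found

saved_cache_cost: float = response_cost_calculator() or 0.0
assert saved_cache_cost == 0.0
```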
@@ -103,25 +103,6 @@ class MistralEmbeddingConfig:
         return optional_params


-class AzureAIStudioConfig:
-    def get_required_params(self) -> List[ProviderField]:
-        """For a given provider, return it's required fields with a description"""
-        return [
-            ProviderField(
-                field_name="api_key",
-                field_type="string",
-                field_description="Your Azure AI Studio API Key.",
-                field_value="zEJ...",
-            ),
-            ProviderField(
-                field_name="api_base",
-                field_type="string",
-                field_description="Your Azure AI Studio API Base.",
-                field_value="https://Mistral-serverless.",
-            ),
-        ]
-
-
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
litellm/llms/azure_ai/chat/handler.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+from typing import Any, Callable, List, Optional, Union
+
+from httpx._config import Timeout
+
+from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
+from litellm.llms.OpenAI.openai import OpenAIChatCompletion
+from litellm.types.utils import ModelResponse
+from litellm.utils import CustomStreamWrapper
+
+from .transformation import AzureAIStudioConfig
+
+
+class AzureAIChatCompletion(OpenAIChatCompletion):
+    def completion(
+        self,
+        model_response: ModelResponse,
+        timeout: Union[float, Timeout],
+        optional_params: dict,
+        logging_obj: Any,
+        model: Optional[str] = None,
+        messages: Optional[list] = None,
+        print_verbose: Optional[Callable[..., Any]] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        acompletion: bool = False,
+        litellm_params=None,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        custom_prompt_dict: dict = {},
+        client=None,
+        organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
+    ):
+
+        transformed_messages = AzureAIStudioConfig()._transform_messages(
+            messages=messages  # type: ignore
+        )
+
+        return super().completion(
+            model_response,
+            timeout,
+            optional_params,
+            logging_obj,
+            model,
+            transformed_messages,
+            print_verbose,
+            api_key,
+            api_base,
+            acompletion,
+            litellm_params,
+            logger_fn,
+            headers,
+            custom_prompt_dict,
+            client,
+            organization,
+            custom_llm_provider,
+            drop_params,
+        )
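A hedged usage sketch (mirroring the updated test further down) showing what this handler enables: Azure AI calls now accept a content list, which is flattened to a string before delegating to the OpenAI-compatible path. The endpoint and key values are placeholders:

```python
# Sketch: calling an Azure AI serverless deployment through litellm with a
# content list; AzureAIChatCompletion flattens it to a plain string.
import os
import litellm

os.environ["AZURE_AI_API_BASE"] = "https://my-endpoint.inference.ai.azure.com"  # placeholder
os.environ["AZURE_AI_API_KEY"] = "my-azure-ai-key"  # placeholder

response = litellm.completion(
    model="azure_ai/command-r-plus",
    messages=[
        {
            "role": "user",
            "content": [{"type": "text", "text": "What is the meaning of life?"}],
        }
    ],
)
print(response.choices[0].message.content)
```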
litellm/llms/azure_ai/chat/transformation.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from typing import List
+
+from litellm.llms.OpenAI.openai import OpenAIConfig
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ProviderField
+
+
+class AzureAIStudioConfig(OpenAIConfig):
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return it's required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Azure AI Studio API Key.",
+                field_value="zEJ...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Azure AI Studio API Base.",
+                field_value="https://Mistral-serverless.",
+            ),
+        ]
+
+    def _transform_messages(self, messages: List[AllMessageValues]) -> List:
+        for message in messages:
+            message = convert_content_list_to_str(message=message)
+
+        return messages
@@ -10,7 +10,7 @@ from litellm.utils import get_model_info

 # Extract the number of billion parameters from the model name
 # only used for together_computer LLMs
-def get_model_params_and_category(model_name: str) -> str:
+def get_base_model_for_pricing(model_name: str) -> str:
     """
     Helper function for calculating together ai pricing.

@@ -43,7 +43,7 @@ def get_model_params_and_category(model_name: str) -> str:
         return "fireworks-ai-16b-80b"

     # If no matches, return the original model_name
-    return model_name
+    return "fireworks-ai-default"


 def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
@@ -57,10 +57,16 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
     Returns:
         Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
     """
-    base_model = get_model_params_and_category(model_name=model)
-
-    ## GET MODEL INFO
-    model_info = get_model_info(model=base_model, custom_llm_provider="fireworks_ai")
+    ## check if model mapped, else use default pricing
+    try:
+        model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
+    except Exception:
+        base_model = get_base_model_for_pricing(model_name=model)
+
+        ## GET MODEL INFO
+        model_info = get_model_info(
+            model=base_model, custom_llm_provider="fireworks_ai"
+        )

     ## CALCULATE INPUT COST

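A hedged usage sketch based on the updated tests below: unmapped or very large Fireworks AI models now fall back to the `fireworks-ai-default` pricing entry instead of failing cost calculation:

```python
# Sketch: models that don't match a known size bucket map to "fireworks-ai-default",
# whose input/output cost per token is 0.0 in model_prices_and_context_window.json.
from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing

assert (
    get_base_model_for_pricing(model_name="fireworks_ai/llama-v3p1-405b-instruct")
    == "fireworks-ai-default"
)
assert (
    get_base_model_for_pricing(model_name="fireworks_ai/mixtral-8x7b-instruct")
    == "fireworks-ai-moe-up-to-56b"
)
```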
litellm/llms/prompt_templates/common_utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+"""
+Common utility functions used for translating messages across providers
+"""
+
+from typing import List
+
+from litellm.types.llms.openai import AllMessageValues
+
+
+def convert_content_list_to_str(message: AllMessageValues) -> AllMessageValues:
+    """
+    - handles scenario where content is list and not string
+    - content list is just text, and no images
+    - if image passed in, then just return as is (user-intended)
+
+    Motivation: mistral api + azure ai don't support content as a list
+    """
+    texts = ""
+    message_content = message.get("content")
+    if message_content:
+        if message_content is not None and isinstance(message_content, list):
+            for c in message_content:
+                text_content = c.get("text")
+                if text_content:
+                    texts += text_content
+        elif message_content is not None and isinstance(message_content, str):
+            texts = message_content
+
+    if texts:
+        message["content"] = texts
+
+    return message
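A minimal sketch of what this new helper does to a text-only content list (per its docstring, image parts would be left untouched):

```python
# Sketch: a text-only content list is flattened into a plain string in place;
# messages whose content is already a string pass through unchanged.
from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str

message = {
    "role": "user",
    "content": [{"type": "text", "text": "What is the meaning of life?"}],
}
flattened = convert_content_list_to_str(message=message)
assert flattened["content"] == "What is the meaning of life?"
```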
@@ -49,6 +49,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
+    Candidates,
     ContentType,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -187,7 +188,11 @@ class VertexAIConfig:
             optional_params["stop_sequences"] = value
         if param == "max_tokens" or param == "max_completion_tokens":
             optional_params["max_output_tokens"] = value
-        if param == "response_format" and value["type"] == "json_object":
+        if (
+            param == "response_format"
+            and isinstance(value, dict)
+            and value["type"] == "json_object"
+        ):
             optional_params["response_mime_type"] = "application/json"
         if param == "frequency_penalty":
             optional_params["frequency_penalty"] = value
@@ -900,14 +905,14 @@ class VertexLLM(VertexBase):

             return model_response

-        if len(completion_response["candidates"]) > 0:
+        _candidates = completion_response.get("candidates")
+        if _candidates and len(_candidates) > 0:
             content_policy_violations = (
                 VertexGeminiConfig().get_flagged_finish_reasons()
             )
             if (
-                "finishReason" in completion_response["candidates"][0]
-                and completion_response["candidates"][0]["finishReason"]
-                in content_policy_violations.keys()
+                "finishReason" in _candidates[0]
+                and _candidates[0]["finishReason"] in content_policy_violations.keys()
             ):
                 ## CONTENT POLICY VIOLATION ERROR
                 model_response.choices[0].finish_reason = "content_filter"
@@ -956,12 +961,13 @@ class VertexLLM(VertexBase):
         content_str = ""
         tools: List[ChatCompletionToolCallChunk] = []
         functions: Optional[ChatCompletionToolCallFunctionChunk] = None
-        for idx, candidate in enumerate(completion_response["candidates"]):
-            if "content" not in candidate:
-                continue
-
-            if "groundingMetadata" in candidate:
-                grounding_metadata.append(candidate["groundingMetadata"])
-
-            if "safetyRatings" in candidate:
-                safety_ratings.append(candidate["safetyRatings"])
+        if _candidates:
+            for idx, candidate in enumerate(_candidates):
+                if "content" not in candidate:
+                    continue
+
+                if "groundingMetadata" in candidate:
+                    grounding_metadata.append(candidate["groundingMetadata"])  # type: ignore
+
+                if "safetyRatings" in candidate:
+                    safety_ratings.append(candidate["safetyRatings"])
@@ -973,7 +979,9 @@ class VertexLLM(VertexBase):

                 if "functionCall" in candidate["content"]["parts"][0]:
                     _function_chunk = ChatCompletionToolCallFunctionChunk(
-                        name=candidate["content"]["parts"][0]["functionCall"]["name"],
+                        name=candidate["content"]["parts"][0]["functionCall"][
+                            "name"
+                        ],
                         arguments=json.dumps(
                             candidate["content"]["parts"][0]["functionCall"]["args"]
                         ),
@@ -1433,10 +1441,12 @@ class ModelResponseIterator:
             is_finished = False
             finish_reason = ""
             usage: Optional[ChatCompletionUsageBlock] = None
+            _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
+            gemini_chunk: Optional[Candidates] = None
+            if _candidates and len(_candidates) > 0:
+                gemini_chunk = _candidates[0]

-            gemini_chunk = processed_chunk["candidates"][0]
-
-            if "content" in gemini_chunk:
+            if gemini_chunk and "content" in gemini_chunk:
                 if "text" in gemini_chunk["content"]["parts"][0]:
                     text = gemini_chunk["content"]["parts"][0]["text"]
                 elif "functionCall" in gemini_chunk["content"]["parts"][0]:
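For context, a hedged sketch of the failure mode this fixes: a Gemini streaming chunk can arrive without a `candidates` key (for example, a chunk carrying only usage metadata; the metadata field names below are illustrative). Direct indexing raised, while `.get()` plus the guard above skips it cleanly:

```python
# Sketch: a chunk with no "candidates" key. Indexing processed_chunk["candidates"][0]
# raises KeyError; .get("candidates") returns None so the guard can skip the chunk.
processed_chunk = {"usageMetadata": {"promptTokenCount": 10, "totalTokenCount": 10}}

_candidates = processed_chunk.get("candidates")  # None instead of KeyError
gemini_chunk = _candidates[0] if _candidates else None
assert gemini_chunk is None
```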
@@ -1455,7 +1465,7 @@ class ModelResponseIterator:
                         index=0,
                     )

-            if "finishReason" in gemini_chunk:
+            if gemini_chunk and "finishReason" in gemini_chunk:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
@@ -1533,6 +1543,7 @@ class ModelResponseIterator:
         )

     def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
+        try:
             chunk = chunk.replace("data:", "")
             if len(chunk) > 0:
                 """
@@ -1544,7 +1555,7 @@ class ModelResponseIterator:
                 return self.handle_valid_json_chunk(chunk=chunk)
             elif self.chunk_type == "accumulated_json":
                 return self.handle_accumulated_json_chunk(chunk=chunk)
-        else:
+
             return GenericStreamingChunk(
                 text="",
                 is_finished=False,
@@ -1553,6 +1564,8 @@ class ModelResponseIterator:
                 index=0,
                 tool_use=None,
             )
+        except Exception:
+            raise

     def __next__(self):
         try:
@@ -83,6 +83,7 @@ from .llms import (
 from .llms.AI21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.anthropic.completion import AnthropicTextCompletion
+from .llms.azure_ai.chat.handler import AzureAIChatCompletion
 from .llms.azure_text import AzureTextCompletion
 from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
 from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
@@ -166,6 +167,7 @@ openai_text_completions = OpenAITextCompletion()
 openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
+azure_ai_chat_completions = AzureAIChatCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -1177,7 +1179,7 @@ def completion(
             headers = headers or litellm.headers

             ## LOAD CONFIG - if set
-            config = litellm.OpenAIConfig.get_config()
+            config = litellm.AzureAIStudioConfig.get_config()
             for k, v in config.items():
                 if (
                     k not in optional_params
@@ -1190,7 +1192,7 @@ def completion(

             ## COMPLETION CALL
             try:
-                response = openai_chat_completions.completion(
+                response = azure_ai_chat_completions.completion(
                     model=model,
                     messages=messages,
                     headers=headers,
@@ -3862,9 +3862,9 @@
         "supports_vision": true
     },
     "anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3906,9 +3906,9 @@
         "supports_vision": true
     },
     "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3939,9 +3939,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-sonnet-20240229-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -3950,9 +3950,9 @@
         "supports_vision": true
     },
     "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
-        "max_tokens": 8192,
+        "max_tokens": 4096,
         "max_input_tokens": 200000,
-        "max_output_tokens": 8192,
+        "max_output_tokens": 4096,
         "input_cost_per_token": 0.000003,
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
@@ -5593,6 +5593,11 @@
         "output_cost_per_token": 0.0000012,
         "litellm_provider": "fireworks_ai"
     },
+    "fireworks-ai-default": {
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "fireworks_ai"
+    },
     "fireworks-ai-embedding-up-to-150m": {
         "input_cost_per_token": 0.000000008,
         "output_cost_per_token": 0.000000,
|
@ -31,6 +31,13 @@ model_list:
|
||||||
- model_name: "anthropic/*"
|
- model_name: "anthropic/*"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: "anthropic/*"
|
model: "anthropic/*"
|
||||||
|
- model_name: "openai/*"
|
||||||
|
litellm_params:
|
||||||
|
model: "openai/*"
|
||||||
|
- model_name: "fireworks_ai/*"
|
||||||
|
litellm_params:
|
||||||
|
model: "fireworks_ai/*"
|
||||||
|
configurable_clientside_auth_params: ["api_base"]
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
|
|
|
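A hedged sketch of what this config enables: a proxy client may now pass `api_base` in the request body for the `fireworks_ai/*` model group (useful for pointing at a finetuned Fireworks deployment). The proxy URL, key, model name, and base URL below are placeholders:

```python
# Sketch: an OpenAI SDK client pointed at the LiteLLM proxy, passing api_base
# client-side via extra_body; allowed only because the fireworks_ai/* group sets
# configurable_clientside_auth_params: ["api_base"] in the proxy config above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # proxy URL/key placeholders

response = client.chat.completions.create(
    model="fireworks_ai/my-finetuned-model",  # hypothetical finetuned model name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    extra_body={"api_base": "https://api.fireworks.ai/inference/v1"},  # placeholder base
)
print(response.choices[0].message.content)
```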
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple

 from fastapi import HTTPException, Request, status

+from litellm import Router, provider_list
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *

@@ -72,7 +73,41 @@ def check_complete_credentials(request_body: dict) -> bool:
     return False


-def is_request_body_safe(request_body: dict) -> bool:
+def _allow_model_level_clientside_configurable_parameters(
+    model: str, param: str, llm_router: Optional[Router]
+) -> bool:
+    """
+    Check if model is allowed to use configurable client-side params
+    - get matching model
+    - check if 'clientside_configurable_parameters' is set for model
+    -
+    """
+    if llm_router is None:
+        return False
+    # check if model is set
+    model_info = llm_router.get_model_group_info(model_group=model)
+    if model_info is None:
+        # check if wildcard model is set
+        if model.split("/", 1)[0] in provider_list:
+            model_info = llm_router.get_model_group_info(
+                model_group=model.split("/", 1)[0]
+            )
+
+    if model_info is None:
+        return False
+
+    if model_info is None or model_info.configurable_clientside_auth_params is None:
+        return False
+
+    if param in model_info.configurable_clientside_auth_params:
+        return True
+
+    return False
+
+
+def is_request_body_safe(
+    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
+) -> bool:
     """
     Check if the request body is safe.

@@ -88,7 +123,20 @@ def is_request_body_safe(request_body: dict) -> bool:
                 request_body=request_body
             )
         ):
-            raise ValueError(f"BadRequest: {param} is not allowed in request body")
+            if general_settings.get("allow_client_side_credentials") is True:
+                return True
+            elif (
+                _allow_model_level_clientside_configurable_parameters(
+                    model=model, param=param, llm_router=llm_router
+                )
+                is True
+            ):
+                return True
+            raise ValueError(
+                f"Rejected Request: {param} is not allowed in request body. "
+                "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
+                "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
+            )

     return True
@@ -110,13 +158,20 @@ async def pre_db_read_auth_checks(
     Raises:
     - HTTPException if request fails initial auth checks
     """
-    from litellm.proxy.proxy_server import general_settings, premium_user
+    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

     # Check 1. request size
     await check_if_request_size_is_safe(request=request)

     # Check 2. Request body is safe
-    is_request_body_safe(request_body=request_data)
+    is_request_body_safe(
+        request_body=request_data,
+        general_settings=general_settings,
+        llm_router=llm_router,
+        model=request_data.get(
+            "model", ""
+        ),  # [TODO] use model passed in url as well (azure openai routes)
+    )

     # Check 3. Check if IP address is allowed
     is_valid_ip, passed_in_ip = _check_valid_ip(
@@ -66,7 +66,7 @@ async def route_request(
     """
     router_model_names = llm_router.model_names if llm_router is not None else []

-    if "api_key" in data:
+    if "api_key" in data or "api_base" in data:
         return getattr(litellm, f"{route_type}")(**data)

     elif "user_config" in data:
@@ -4,6 +4,8 @@ import secrets
 import traceback
 from typing import Optional

+from pydantic import BaseModel
+
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
@@ -105,6 +107,8 @@ def get_logging_payload(
     additional_usage_values = {}
     for k, v in usage.items():
         if k not in special_usage_fields:
+            if isinstance(v, BaseModel):
+                v = v.model_dump()
             additional_usage_values.update({k: v})
     clean_metadata["additional_usage_values"] = additional_usage_values

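A small hedged sketch of why the `isinstance(v, BaseModel)` check is needed: usage objects can carry nested pydantic models (the class and field names below are illustrative), and `model_dump()` converts them to plain dicts before they go into the spend-log payload:

```python
# Sketch: nested pydantic values are converted to plain dicts so the spend-log
# payload stays JSON-serializable; plain values pass through untouched.
from pydantic import BaseModel

class TokenDetails(BaseModel):  # illustrative stand-in for a nested usage field
    reasoning_tokens: int = 0

usage = {"total_tokens": 42, "completion_tokens_details": TokenDetails(reasoning_tokens=5)}

additional_usage_values = {}
for k, v in usage.items():
    if isinstance(v, BaseModel):
        v = v.model_dump()
    additional_usage_values.update({k: v})

assert additional_usage_values["completion_tokens_details"] == {"reasoning_tokens": 5}
```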
@@ -14,7 +14,17 @@ from datetime import datetime, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from functools import wraps
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    overload,
+)

 import backoff
 import httpx
@@ -222,19 +232,7 @@ class ProxyLogging:
         self.cache_control_check = _PROXY_CacheControlCheck()
         self.alerting: Optional[List] = None
         self.alerting_threshold: float = 300  # default to 5 min. threshold
-        self.alert_types: List[AlertType] = [
-            "llm_exceptions",
-            "llm_too_slow",
-            "llm_requests_hanging",
-            "budget_alerts",
-            "db_exceptions",
-            "daily_reports",
-            "spend_reports",
-            "fallback_reports",
-            "cooldown_deployment",
-            "new_model_added",
-            "outage_alerts",
-        ]
+        self.alert_types: List[AlertType] = list(get_args(AlertType))
         self.alert_to_webhook_url: Optional[dict] = None
         self.slack_alerting_instance: SlackAlerting = SlackAlerting(
             alerting_threshold=self.alerting_threshold,
@@ -4335,11 +4335,28 @@ class Router:

         total_tpm: Optional[int] = None
         total_rpm: Optional[int] = None
+        configurable_clientside_auth_params: Optional[List[str]] = None

         for model in self.model_list:
-            if "model_name" in model and model["model_name"] == model_group:
+            is_match = False
+            if (
+                "model_name" in model and model["model_name"] == model_group
+            ):  # exact match
+                is_match = True
+            elif (
+                "model_name" in model
+                and model_group in self.provider_default_deployments
+            ):  # wildcard model
+                is_match = True
+
+            if not is_match:
+                continue
             # model in model group found #
             litellm_params = LiteLLM_Params(**model["litellm_params"])
+            # get configurable clientside auth params
+            configurable_clientside_auth_params = (
+                litellm_params.configurable_clientside_auth_params
+            )
             # get model tpm
             _deployment_tpm: Optional[int] = None
             if _deployment_tpm is None:
@@ -4425,9 +4442,7 @@ class Router:
                         > model_group_info.max_input_tokens
                     )
                 ):
-                    model_group_info.max_input_tokens = model_info[
-                        "max_input_tokens"
-                    ]
+                    model_group_info.max_input_tokens = model_info["max_input_tokens"]
                 if (
                     model_info.get("max_output_tokens", None) is not None
                     and model_info["max_output_tokens"] is not None
@@ -4437,9 +4452,7 @@ class Router:
                         > model_group_info.max_output_tokens
                     )
                 ):
-                    model_group_info.max_output_tokens = model_info[
-                        "max_output_tokens"
-                    ]
+                    model_group_info.max_output_tokens = model_info["max_output_tokens"]
                 if model_info.get("input_cost_per_token", None) is not None and (
                     model_group_info.input_cost_per_token is None
                     or model_info["input_cost_per_token"]
@@ -4480,13 +4493,20 @@ class Router:
                         "supported_openai_params"
                     ]

-        ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
-        if total_tpm is not None and model_group_info is not None:
-            model_group_info.tpm = total_tpm
-
-        if total_rpm is not None and model_group_info is not None:
-            model_group_info.rpm = total_rpm
+        if model_group_info is not None:
+            ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
+            if total_tpm is not None:
+                model_group_info.tpm = total_tpm
+
+            if total_rpm is not None:
+                model_group_info.rpm = total_rpm
+
+            ## UPDATE WITH CONFIGURABLE CLIENTSIDE AUTH PARAMS FOR MODEL GROUP
+            if configurable_clientside_auth_params is not None:
+                model_group_info.configurable_clientside_auth_params = (
+                    configurable_clientside_auth_params
+                )

         return model_group_info

     def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
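A hedged sketch (mirroring the new auth tests below) of how the router now surfaces client-side configurable params for a wildcard model group:

```python
# Sketch: a wildcard fireworks_ai deployment with configurable_clientside_auth_params;
# get_model_group_info now reports which params a client may override.
import os
from litellm import Router

llm_router = Router(
    model_list=[
        {
            "model_name": "fireworks_ai/*",
            "litellm_params": {
                "model": "fireworks_ai/*",
                "api_key": os.getenv("FIREWORKS_API_KEY"),
                "configurable_clientside_auth_params": ["api_base"],
            },
        }
    ]
)

model_info = llm_router.get_model_group_info(model_group="fireworks_ai")
# expected: ["api_base"] (sketch; assumes the wildcard group resolves to this deployment)
print(model_info.configurable_clientside_auth_params)
```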
@@ -141,9 +141,16 @@ def test_completion_azure_ai_command_r():
         os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
         os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")

-        response: litellm.ModelResponse = completion(
+        response = completion(
             model="azure_ai/command-r-plus",
-            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the meaning of life?"}
+                    ],
+                }
+            ],
         )  # type: ignore

         assert "azure_ai" in response.model
@@ -1257,14 +1257,31 @@ def test_completion_cost_databricks_embedding(model):
     cost = completion_cost(completion_response=resp)


-def test_completion_cost_fireworks_ai():
+from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing
+
+
+@pytest.mark.parametrize(
+    "model, base_model",
+    [
+        ("fireworks_ai/llama-v3p1-405b-instruct", "fireworks-ai-default"),
+        ("fireworks_ai/mixtral-8x7b-instruct", "fireworks-ai-moe-up-to-56b"),
+    ],
+)
+def test_get_model_params_fireworks_ai(model, base_model):
+    pricing_model = get_base_model_for_pricing(model_name=model)
+    assert base_model == pricing_model
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["fireworks_ai/llama-v3p1-405b-instruct", "fireworks_ai/mixtral-8x7b-instruct"],
+)
+def test_completion_cost_fireworks_ai(model):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")

     messages = [{"role": "user", "content": "Hey, how's it going?"}]
-    resp = litellm.completion(
-        model="fireworks_ai/mixtral-8x7b-instruct", messages=messages
-    )  # works fine
+    resp = litellm.completion(model=model, messages=messages)  # works fine

     print(resp)
     cost = completion_cost(completion_response=resp)
@@ -12,6 +12,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+from litellm.proxy.auth.auth_utils import is_request_body_safe
 from litellm.proxy.litellm_pre_call_utils import (
     _get_dynamic_logging_metadata,
     add_litellm_data_to_request,
@@ -291,3 +292,78 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):

     for var in callbacks.callback_vars.values():
         assert "os.environ" not in var
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_global_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={
+                "allow_client_side_credentials": allow_client_side_credentials
+            },
+            llm_router=llm_router,
+            model="gpt-3.5-turbo",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised
+
+
+@pytest.mark.parametrize(
+    "allow_client_side_credentials, expect_error", [(True, False), (False, True)]
+)
+def test_is_request_body_safe_model_enabled(
+    allow_client_side_credentials, expect_error
+):
+    from litellm import Router
+
+    error_raised = False
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": "fireworks_ai/*",
+                "litellm_params": {
+                    "model": "fireworks_ai/*",
+                    "api_key": os.getenv("FIREWORKS_API_KEY"),
+                    "configurable_clientside_auth_params": (
+                        ["api_base"] if allow_client_side_credentials else []
+                    ),
+                },
+            }
+        ]
+    )
+    try:
+        is_request_body_safe(
+            request_body={"api_base": "hello-world"},
+            general_settings={},
+            llm_router=llm_router,
+            model="fireworks_ai/my-new-model",
+        )
+    except Exception as e:
+        print(e)
+        error_raised = True
+
+    assert expect_error == error_raised
@@ -283,7 +283,7 @@ class PromptFeedback(TypedDict):


 class GenerateContentResponseBody(TypedDict, total=False):
-    candidates: Required[List[Candidates]]
+    candidates: List[Candidates]
     promptFeedback: PromptFeedback
     usageMetadata: Required[UsageMetadata]

@@ -139,6 +139,7 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
+    configurable_clientside_auth_params: Optional[List[str]] = None
     ## UNIFIED PROJECT/REGION ##
     region_name: Optional[str] = None
     ## VERTEX AI ##
@@ -310,6 +311,9 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     stream_timeout: Optional[Union[float, str]]
     max_retries: Optional[int]
     organization: Optional[Union[List, str]]  # for openai orgs
+    configurable_clientside_auth_params: Optional[
+        List[str]
+    ]  # for allowing api base switching on finetuned models
     ## DROP PARAMS ##
     drop_params: Optional[bool]
     ## UNIFIED PROJECT/REGION ##
@@ -487,6 +491,7 @@ class ModelGroupInfo(BaseModel):
     supports_vision: bool = Field(default=False)
     supports_function_calling: bool = Field(default=False)
     supported_openai_params: Optional[List[str]] = Field(default=[])
+    configurable_clientside_auth_params: Optional[List[str]] = None


 class AssistantsTypedDict(TypedDict):
@@ -1196,6 +1196,7 @@ all_litellm_params = [
     "client_id",
     "client_secret",
     "user_continue_message",
+    "configurable_clientside_auth_params",
 ]

@@ -1323,7 +1324,7 @@ class StandardLoggingPayload(TypedDict):
     metadata: StandardLoggingMetadata
     cache_hit: Optional[bool]
     cache_key: Optional[str]
-    saved_cache_cost: Optional[float]
+    saved_cache_cost: float
     request_tags: list
     end_user: Optional[str]
     requester_ip_address: Optional[str]