forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_ui_fixes_6
This commit is contained in:
commit
0e709fdc21
26 changed files with 631 additions and 185 deletions
|
@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
|||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
|
||||
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
|
||||
|
|
54
docs/my-website/docs/providers/deepseek.md
Normal file
54
docs/my-website/docs/providers/deepseek.md
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Deepseek
|
||||
https://deepseek.com/
|
||||
|
||||
**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**
|
||||
|
||||
## API Key
|
||||
```python
|
||||
# env variable
|
||||
os.environ['DEEPSEEK_API_KEY']
|
||||
```
|
||||
|
||||
## Sample Usage
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="deepseek/deepseek-chat",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Sample Usage - Streaming
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="deepseek/deepseek-chat",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
|
||||
## Supported Models - ALL Deepseek Models Supported!
|
||||
We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||
| deepseek-coder | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||
|
||||
|
|
@ -468,7 +468,7 @@ asyncio.run(router_acompletion())
|
|||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="lowest-cost" label="Lowest Cost Routing">
|
||||
<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
|
||||
|
||||
Picks a deployment based on the lowest cost
|
||||
|
||||
|
@ -1086,6 +1086,46 @@ async def test_acompletion_caching_on_router_caching_groups():
|
|||
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
||||
```
|
||||
|
||||
## Alerting 🚨
|
||||
|
||||
Send alerts to slack / your webhook url for the following events
|
||||
- LLM API Exceptions
|
||||
- Slow LLM Responses
|
||||
|
||||
Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||
|
||||
#### Usage
|
||||
Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key=bad-key` which is invalid
|
||||
|
||||
```python
|
||||
from litellm.router import AlertingConfig
|
||||
import litellm
|
||||
import os
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "bad_key",
|
||||
},
|
||||
}
|
||||
],
|
||||
alerting_config= AlertingConfig(
|
||||
alerting_threshold=10, # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
|
||||
webhook_url= os.getenv("SLACK_WEBHOOK_URL") # webhook you want to send alerts to
|
||||
),
|
||||
)
|
||||
try:
|
||||
await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
)
|
||||
except:
|
||||
pass
|
||||
```
|
||||
|
||||
## Track cost for Azure Deployments
|
||||
|
||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||
|
|
|
@ -134,6 +134,7 @@ const sidebars = {
|
|||
"providers/ollama",
|
||||
"providers/perplexity",
|
||||
"providers/groq",
|
||||
"providers/deepseek",
|
||||
"providers/fireworks_ai",
|
||||
"providers/vllm",
|
||||
"providers/xinference",
|
||||
|
|
|
@ -361,6 +361,7 @@ openai_compatible_endpoints: List = [
|
|||
"api.deepinfra.com/v1/openai",
|
||||
"api.mistral.ai/v1",
|
||||
"api.groq.com/openai/v1",
|
||||
"api.deepseek.com/v1",
|
||||
"api.together.xyz/v1",
|
||||
]
|
||||
|
||||
|
@ -369,6 +370,7 @@ openai_compatible_providers: List = [
|
|||
"anyscale",
|
||||
"mistral",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"deepinfra",
|
||||
"perplexity",
|
||||
"xinference",
|
||||
|
@ -523,6 +525,7 @@ provider_list: List = [
|
|||
"anyscale",
|
||||
"mistral",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"maritalk",
|
||||
"voyage",
|
||||
"cloudflare",
|
||||
|
|
|
@ -262,7 +262,7 @@ class LangFuseLogger:
|
|||
|
||||
try:
|
||||
tags = []
|
||||
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
|
||||
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
|
||||
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
||||
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
|
@ -276,7 +276,6 @@ class LangFuseLogger:
|
|||
metadata_tags = metadata.pop("tags", [])
|
||||
tags = metadata_tags
|
||||
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
|
@ -303,18 +302,17 @@ class LangFuseLogger:
|
|||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
|
||||
session_id = clean_metadata.pop("session_id", None)
|
||||
trace_name = clean_metadata.pop("trace_name", None)
|
||||
trace_id = clean_metadata.pop("trace_id", None)
|
||||
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||
|
||||
|
||||
if trace_name is None and existing_trace_id is None:
|
||||
# just log `litellm-{call_type}` as the trace name
|
||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
||||
|
||||
if existing_trace_id is not None:
|
||||
trace_params = {"id": existing_trace_id}
|
||||
|
||||
|
@ -322,15 +320,18 @@ class LangFuseLogger:
|
|||
for metadata_param_key in update_trace_keys:
|
||||
trace_param_key = metadata_param_key.replace("trace_", "")
|
||||
if trace_param_key not in trace_params:
|
||||
updated_trace_value = clean_metadata.pop(metadata_param_key, None)
|
||||
updated_trace_value = clean_metadata.pop(
|
||||
metadata_param_key, None
|
||||
)
|
||||
if updated_trace_value is not None:
|
||||
trace_params[trace_param_key] = updated_trace_value
|
||||
|
||||
|
||||
# Pop the trace specific keys that would have been popped if there were a new trace
|
||||
for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
clean_metadata.pop(key, None)
|
||||
|
||||
|
||||
# Special keys that are found in the function arguments and not the metadata
|
||||
if "input" in update_trace_keys:
|
||||
trace_params["input"] = input
|
||||
|
@ -342,16 +343,22 @@ class LangFuseLogger:
|
|||
"name": trace_name,
|
||||
"session_id": session_id,
|
||||
"input": input,
|
||||
"version": clean_metadata.pop("trace_version", clean_metadata.get("version", None)), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||
"version": clean_metadata.pop(
|
||||
"trace_version", clean_metadata.get("version", None)
|
||||
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||
}
|
||||
for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
|
||||
trace_params[key.replace("trace_", "")] = clean_metadata.pop(key, None)
|
||||
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
|
||||
key, None
|
||||
)
|
||||
|
||||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = output
|
||||
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
print_verbose(f"trace: {cost}")
|
||||
|
||||
|
@ -454,7 +461,7 @@ class LangFuseLogger:
|
|||
)
|
||||
|
||||
generation_client = trace.generation(**generation_params)
|
||||
|
||||
|
||||
return generation_client.trace_id, generation_id
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
|
|
|
@ -68,11 +68,15 @@ class SlackAlertingCacheKeys(Enum):
|
|||
|
||||
|
||||
class SlackAlerting(CustomLogger):
|
||||
"""
|
||||
Class for sending Slack Alerts
|
||||
"""
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
internal_usage_cache: Optional[DualCache] = None,
|
||||
alerting_threshold: float = 300,
|
||||
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
|
||||
alerting: Optional[List] = [],
|
||||
alert_types: Optional[
|
||||
List[
|
||||
|
@ -97,6 +101,7 @@ class SlackAlerting(CustomLogger):
|
|||
Dict
|
||||
] = None, # if user wants to separate alerts to diff channels
|
||||
alerting_args={},
|
||||
default_webhook_url: Optional[str] = None,
|
||||
):
|
||||
self.alerting_threshold = alerting_threshold
|
||||
self.alerting = alerting
|
||||
|
@ -106,6 +111,7 @@ class SlackAlerting(CustomLogger):
|
|||
self.alert_to_webhook_url = alert_to_webhook_url
|
||||
self.is_running = False
|
||||
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
||||
self.default_webhook_url = default_webhook_url
|
||||
|
||||
def update_values(
|
||||
self,
|
||||
|
@ -149,16 +155,21 @@ class SlackAlerting(CustomLogger):
|
|||
|
||||
def _add_langfuse_trace_id_to_alert(
|
||||
self,
|
||||
request_info: str,
|
||||
request_data: Optional[dict] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
type: Literal["hanging_request", "slow_response"] = "hanging_request",
|
||||
start_time: Optional[datetime.datetime] = None,
|
||||
end_time: Optional[datetime.datetime] = None,
|
||||
):
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns langfuse trace url
|
||||
"""
|
||||
# do nothing for now
|
||||
pass
|
||||
return request_info
|
||||
if (
|
||||
request_data is not None
|
||||
and request_data.get("metadata", {}).get("trace_id", None) is not None
|
||||
):
|
||||
trace_id = request_data["metadata"]["trace_id"]
|
||||
if litellm.utils.langFuseLogger is not None:
|
||||
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
|
||||
return f"{base_url}/trace/{trace_id}"
|
||||
return None
|
||||
|
||||
def _response_taking_too_long_callback_helper(
|
||||
self,
|
||||
|
@ -302,7 +313,7 @@ class SlackAlerting(CustomLogger):
|
|||
except Exception as e:
|
||||
return 0
|
||||
|
||||
async def send_daily_reports(self, router: litellm.Router) -> bool:
|
||||
async def send_daily_reports(self, router) -> bool:
|
||||
"""
|
||||
Send a daily report on:
|
||||
- Top 5 deployments with most failed requests
|
||||
|
@ -501,14 +512,13 @@ class SlackAlerting(CustomLogger):
|
|||
)
|
||||
|
||||
if "langfuse" in litellm.success_callback:
|
||||
request_info = self._add_langfuse_trace_id_to_alert(
|
||||
request_info=request_info,
|
||||
langfuse_url = self._add_langfuse_trace_id_to_alert(
|
||||
request_data=request_data,
|
||||
type="hanging_request",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
if langfuse_url is not None:
|
||||
request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
|
||||
|
||||
# add deployment latencies to alert
|
||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||
metadata=request_data.get("metadata", {})
|
||||
|
@ -701,6 +711,7 @@ Model Info:
|
|||
"daily_reports",
|
||||
"new_model_added",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
||||
|
@ -731,6 +742,10 @@ Model Info:
|
|||
formatted_message = (
|
||||
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
||||
)
|
||||
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
||||
if _proxy_base_url is not None:
|
||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||
|
||||
|
@ -740,6 +755,8 @@ Model Info:
|
|||
and alert_type in self.alert_to_webhook_url
|
||||
):
|
||||
slack_webhook_url = self.alert_to_webhook_url[alert_type]
|
||||
elif self.default_webhook_url is not None:
|
||||
slack_webhook_url = self.default_webhook_url
|
||||
else:
|
||||
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
||||
|
||||
|
@ -796,8 +813,16 @@ Model Info:
|
|||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
)
|
||||
)
|
||||
if "llm_exceptions" in self.alert_types:
|
||||
original_exception = kwargs.get("exception", None)
|
||||
|
||||
async def _run_scheduler_helper(self, llm_router: litellm.Router) -> bool:
|
||||
await self.send_alert(
|
||||
message="LLM API Failure - " + str(original_exception),
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
)
|
||||
|
||||
async def _run_scheduler_helper(self, llm_router) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True -> report sent
|
||||
|
@ -839,7 +864,7 @@ Model Info:
|
|||
|
||||
return report_sent_bool
|
||||
|
||||
async def _run_scheduled_daily_report(self, llm_router: Optional[litellm.Router]):
|
||||
async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
|
||||
"""
|
||||
If 'daily_reports' enabled
|
||||
|
||||
|
|
|
@ -474,3 +474,23 @@ async def ollama_aembeddings(
|
|||
"total_tokens": total_input_tokens,
|
||||
}
|
||||
return model_response
|
||||
|
||||
def ollama_embeddings(
|
||||
api_base: str,
|
||||
model: str,
|
||||
prompts: list,
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
encoding=None,
|
||||
):
|
||||
return asyncio.run(
|
||||
ollama_aembeddings(
|
||||
api_base,
|
||||
model,
|
||||
prompts,
|
||||
optional_params,
|
||||
logging_obj,
|
||||
model_response,
|
||||
encoding)
|
||||
)
|
||||
|
|
|
@ -305,6 +305,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
|
@ -982,6 +983,7 @@ def completion(
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "openai"
|
||||
|
@ -2168,7 +2170,7 @@ def completion(
|
|||
"""
|
||||
assume input to custom LLM api bases follow this format:
|
||||
resp = requests.post(
|
||||
api_base,
|
||||
api_base,
|
||||
json={
|
||||
'model': 'meta-llama/Llama-2-13b-hf', # model name
|
||||
'params': {
|
||||
|
@ -2565,6 +2567,7 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
|
@ -2947,16 +2950,16 @@ def embedding(
|
|||
model=model, # type: ignore
|
||||
llm_provider="ollama", # type: ignore
|
||||
)
|
||||
if aembedding:
|
||||
response = ollama.ollama_aembeddings(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
prompts=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
|
||||
response = ollama_embeddings_fn(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
prompts=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
response = sagemaker.embedding(
|
||||
model=model,
|
||||
|
@ -3085,6 +3088,7 @@ async def atext_completion(*args, **kwargs):
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"
|
||||
|
|
|
@ -739,6 +739,24 @@
|
|||
"litellm_provider": "mistral",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"deepseek-chat": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"deepseek-coder": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 16000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"groq/llama2-70b-4096": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
|
@ -14,6 +14,9 @@ model_list:
|
|||
api_key: my-fake-key-3
|
||||
model: openai/my-fake-model-3
|
||||
model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
router_settings:
|
||||
num_retries: 0
|
||||
enable_pre_call_checks: true
|
||||
|
@ -25,7 +28,7 @@ router_settings:
|
|||
routing_strategy: "latency-based-routing"
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["openmeter"]
|
||||
success_callback: ["langfuse"]
|
||||
|
||||
general_settings:
|
||||
alerting: ["slack"]
|
||||
|
|
|
@ -2531,6 +2531,7 @@ class ProxyConfig:
|
|||
if "db_model" in model.model_info and model.model_info["db_model"] == False:
|
||||
model.model_info["db_model"] = db_model
|
||||
_model_info = RouterModelInfo(**model.model_info)
|
||||
|
||||
else:
|
||||
_model_info = RouterModelInfo(id=model.model_id, db_model=db_model)
|
||||
return _model_info
|
||||
|
@ -3175,7 +3176,9 @@ def data_generator(response):
|
|||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
|
||||
async def async_data_generator(response, user_api_key_dict):
|
||||
async def async_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
verbose_proxy_logger.debug("inside generator")
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
@ -3192,7 +3195,9 @@ async def async_data_generator(response, user_api_key_dict):
|
|||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
|
@ -3217,8 +3222,14 @@ async def async_data_generator(response, user_api_key_dict):
|
|||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
def select_data_generator(response, user_api_key_dict):
|
||||
return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)
|
||||
def select_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
return async_data_generator(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=request_data,
|
||||
)
|
||||
|
||||
|
||||
def get_litellm_model_info(model: dict = {}):
|
||||
|
@ -3513,9 +3524,8 @@ async def chat_completion(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
data = {}
|
||||
try:
|
||||
# async with llm_router.sem
|
||||
data = {}
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
try:
|
||||
|
@ -3706,7 +3716,9 @@ async def chat_completion(
|
|||
"x-litellm-model-api-base": api_base,
|
||||
}
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response, user_api_key_dict=user_api_key_dict
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
|
@ -3728,7 +3740,7 @@ async def chat_completion(
|
|||
data["litellm_status"] = "fail" # used for alerting
|
||||
traceback.print_exc()
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
|
@ -3890,7 +3902,9 @@ async def completion(
|
|||
"x-litellm-model-id": model_id,
|
||||
}
|
||||
selected_data_generator = select_data_generator(
|
||||
response=response, user_api_key_dict=user_api_key_dict
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
|
@ -3943,6 +3957,7 @@ async def embeddings(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global proxy_logging_obj
|
||||
data: Any = {}
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
|
@ -4088,7 +4103,7 @@ async def embeddings(
|
|||
except Exception as e:
|
||||
data["litellm_status"] = "fail" # used for alerting
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -4125,6 +4140,7 @@ async def image_generation(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global proxy_logging_obj
|
||||
data = {}
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
|
@ -4244,7 +4260,7 @@ async def image_generation(
|
|||
except Exception as e:
|
||||
data["litellm_status"] = "fail" # used for alerting
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -4285,10 +4301,11 @@ async def audio_transcriptions(
|
|||
https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=curl
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Dict = {}
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
form_data = await request.form()
|
||||
data: Dict = {key: value for key, value in form_data.items() if key != "file"}
|
||||
data = {key: value for key, value in form_data.items() if key != "file"}
|
||||
|
||||
# Include original request and headers in the data
|
||||
data["proxy_server_request"] = { # type: ignore
|
||||
|
@ -4423,7 +4440,7 @@ async def audio_transcriptions(
|
|||
except Exception as e:
|
||||
data["litellm_status"] = "fail" # used for alerting
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -4472,6 +4489,7 @@ async def moderations(
|
|||
```
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Dict = {}
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
|
@ -4585,7 +4603,7 @@ async def moderations(
|
|||
except Exception as e:
|
||||
data["litellm_status"] = "fail" # used for alerting
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
traceback.print_exc()
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -8048,8 +8066,8 @@ async def async_queue_request(
|
|||
|
||||
Now using a FastAPI background task + /chat/completions compatible endpoint
|
||||
"""
|
||||
data = {}
|
||||
try:
|
||||
data = {}
|
||||
data = await request.json() # type: ignore
|
||||
|
||||
# Include original request and headers in the data
|
||||
|
@ -8114,7 +8132,9 @@ async def async_queue_request(
|
|||
): # use generate_responses to stream responses
|
||||
return StreamingResponse(
|
||||
async_data_generator(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
request_data=data,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
@ -8122,7 +8142,7 @@ async def async_queue_request(
|
|||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
|
|
|
@ -302,6 +302,7 @@ class ProxyLogging:
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
],
|
||||
request_data: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
||||
|
@ -331,10 +332,19 @@ class ProxyLogging:
|
|||
if _proxy_base_url is not None:
|
||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||
|
||||
extra_kwargs = {}
|
||||
if request_data is not None:
|
||||
_url = self.slack_alerting_instance._add_langfuse_trace_id_to_alert(
|
||||
request_data=request_data
|
||||
)
|
||||
if _url is not None:
|
||||
extra_kwargs["🪢 Langfuse Trace"] = _url
|
||||
formatted_message += "\n\n🪢 Langfuse Trace: {}".format(_url)
|
||||
|
||||
for client in self.alerting:
|
||||
if client == "slack":
|
||||
await self.slack_alerting_instance.send_alert(
|
||||
message=message, level=level, alert_type=alert_type
|
||||
message=message, level=level, alert_type=alert_type, **extra_kwargs
|
||||
)
|
||||
elif client == "sentry":
|
||||
if litellm.utils.sentry_sdk_instance is not None:
|
||||
|
@ -369,6 +379,7 @@ class ProxyLogging:
|
|||
message=f"DB read/write call failed: {error_message}",
|
||||
level="High",
|
||||
alert_type="db_exceptions",
|
||||
request_data={},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -384,7 +395,10 @@ class ProxyLogging:
|
|||
litellm.utils.capture_exception(error=original_exception)
|
||||
|
||||
async def post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
self,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
request_data: dict,
|
||||
):
|
||||
"""
|
||||
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
|
||||
|
@ -409,6 +423,7 @@ class ProxyLogging:
|
|||
message=f"LLM API call failed: {str(original_exception)}",
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
request_data=request_data,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -44,6 +44,7 @@ from litellm.types.router import (
|
|||
updateDeployment,
|
||||
updateLiteLLMParams,
|
||||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
@ -103,6 +104,7 @@ class Router:
|
|||
] = "simple-shuffle",
|
||||
routing_strategy_args: dict = {}, # just for latency-based routing
|
||||
semaphore: Optional[asyncio.Semaphore] = None,
|
||||
alerting_config: Optional[AlertingConfig] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
|
||||
|
@ -131,7 +133,7 @@ class Router:
|
|||
cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1.
|
||||
routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
|
||||
routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
|
||||
|
||||
alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
|
||||
Returns:
|
||||
Router: An instance of the litellm.Router class.
|
||||
|
||||
|
@ -316,6 +318,9 @@ class Router:
|
|||
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
|
||||
model_group_retry_policy
|
||||
)
|
||||
self.alerting_config: Optional[AlertingConfig] = alerting_config
|
||||
if self.alerting_config is not None:
|
||||
self._initialize_alerting()
|
||||
|
||||
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
|
||||
if routing_strategy == "least-busy":
|
||||
|
@ -3000,6 +3005,7 @@ class Router:
|
|||
if (
|
||||
self.routing_strategy != "usage-based-routing-v2"
|
||||
and self.routing_strategy != "simple-shuffle"
|
||||
and self.routing_strategy != "cost-based-routing"
|
||||
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
|
||||
return self.get_available_deployment(
|
||||
model=model,
|
||||
|
@ -3056,6 +3062,16 @@ class Router:
|
|||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
if (
|
||||
self.routing_strategy == "cost-based-routing"
|
||||
and self.lowestcost_logger is not None
|
||||
):
|
||||
deployment = await self.lowestcost_logger.async_get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
elif self.routing_strategy == "simple-shuffle":
|
||||
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
|
||||
############## Check if we can do a RPM/TPM based weighted pick #################
|
||||
|
@ -3226,15 +3242,6 @@ class Router:
|
|||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
elif (
|
||||
self.routing_strategy == "cost-based-routing"
|
||||
and self.lowestcost_logger is not None
|
||||
):
|
||||
deployment = self.lowestcost_logger.get_available_deployments(
|
||||
model_group=model,
|
||||
healthy_deployments=healthy_deployments,
|
||||
request_kwargs=request_kwargs,
|
||||
)
|
||||
if deployment is None:
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
|
@ -3360,6 +3367,23 @@ class Router:
|
|||
):
|
||||
return retry_policy.ContentPolicyViolationErrorRetries
|
||||
|
||||
def _initialize_alerting(self):
|
||||
from litellm.integrations.slack_alerting import SlackAlerting
|
||||
|
||||
router_alerting_config: AlertingConfig = self.alerting_config
|
||||
|
||||
_slack_alerting_logger = SlackAlerting(
|
||||
alerting_threshold=router_alerting_config.alerting_threshold,
|
||||
alerting=["slack"],
|
||||
default_webhook_url=router_alerting_config.webhook_url,
|
||||
)
|
||||
|
||||
litellm.callbacks.append(_slack_alerting_logger)
|
||||
litellm.success_callback.append(
|
||||
_slack_alerting_logger.response_taking_too_long_callback
|
||||
)
|
||||
print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
|
||||
|
||||
def flush_cache(self):
|
||||
litellm.cache = None
|
||||
self.cache.flush_cache()
|
||||
|
|
|
@ -40,7 +40,7 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
self.router_cache = router_cache
|
||||
self.model_list = model_list
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
async def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
"""
|
||||
Update usage on success
|
||||
|
@ -90,7 +90,11 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
# Update usage
|
||||
# ------------
|
||||
|
||||
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=cost_key) or {}
|
||||
)
|
||||
|
||||
# check local result first
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
@ -111,7 +115,9 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
self.router_cache.set_cache(key=cost_key, value=request_count_dict)
|
||||
await self.router_cache.async_set_cache(
|
||||
key=cost_key, value=request_count_dict
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
|
@ -172,7 +178,9 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
# Update usage
|
||||
# ------------
|
||||
|
||||
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=cost_key) or {}
|
||||
)
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
@ -189,7 +197,7 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
self.router_cache.set_cache(
|
||||
await self.router_cache.async_set_cache(
|
||||
key=cost_key, value=request_count_dict
|
||||
) # reset map within window
|
||||
|
||||
|
@ -200,7 +208,7 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
def get_available_deployments(
|
||||
async def async_get_available_deployments(
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
|
@ -213,7 +221,7 @@ class LowestCostLoggingHandler(CustomLogger):
|
|||
"""
|
||||
cost_key = f"{model_group}_map"
|
||||
|
||||
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
|
||||
request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
|
||||
|
||||
# -----------------------
|
||||
# Find lowest used model
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -18,6 +18,10 @@ from unittest.mock import patch, MagicMock
|
|||
from litellm.utils import get_api_base
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.slack_alerting import SlackAlerting, DeploymentMetrics
|
||||
import unittest.mock
|
||||
from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from litellm.router import AlertingConfig, Router
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -313,3 +317,45 @@ async def test_daily_reports_redis_cache_scheduler():
|
|||
|
||||
# second call - expect empty
|
||||
await slack_alerting._run_scheduler_helper(llm_router=router)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Local test. Test if slack alerts are sent.")
|
||||
async def test_send_llm_exception_to_slack():
|
||||
from litellm.router import AlertingConfig
|
||||
|
||||
# on async success
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "bad_key",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-5-good",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
],
|
||||
alerting_config=AlertingConfig(
|
||||
alerting_threshold=0.5, webhook_url=os.getenv("SLACK_WEBHOOK_URL")
|
||||
),
|
||||
)
|
||||
try:
|
||||
await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
await router.acompletion(
|
||||
model="gpt-5-good",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
|
|
|
@ -1061,16 +1061,16 @@ def test_completion_perplexity_api_2():
|
|||
######### HUGGING FACE TESTS ########################
|
||||
#####################################################
|
||||
"""
|
||||
HF Tests we should pass
|
||||
- TGI:
|
||||
- Pro Inference API
|
||||
- Deployed Endpoint
|
||||
- Coversational
|
||||
- Free Inference API
|
||||
- Deployed Endpoint
|
||||
HF Tests we should pass
|
||||
- TGI:
|
||||
- Pro Inference API
|
||||
- Deployed Endpoint
|
||||
- Coversational
|
||||
- Free Inference API
|
||||
- Deployed Endpoint
|
||||
- Neither TGI or Coversational
|
||||
- Free Inference API
|
||||
- Deployed Endpoint
|
||||
- Free Inference API
|
||||
- Deployed Endpoint
|
||||
"""
|
||||
|
||||
|
||||
|
@ -2168,9 +2168,9 @@ def test_completion_replicate_vicuna():
|
|||
|
||||
def test_replicate_custom_prompt_dict():
|
||||
litellm.set_verbose = True
|
||||
model_name = "replicate/meta/llama-2-70b-chat"
|
||||
model_name = "replicate/meta/llama-2-7b"
|
||||
litellm.register_prompt_template(
|
||||
model="replicate/meta/llama-2-70b-chat",
|
||||
model="replicate/meta/llama-2-7b",
|
||||
initial_prompt_value="You are a good assistant", # [OPTIONAL]
|
||||
roles={
|
||||
"system": {
|
||||
|
@ -2200,6 +2200,7 @@ def test_replicate_custom_prompt_dict():
|
|||
repetition_penalty=0.1,
|
||||
num_retries=3,
|
||||
)
|
||||
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except litellm.APIConnectionError as e:
|
||||
|
@ -3017,6 +3018,21 @@ async def test_acompletion_gemini():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# Deepseek tests
|
||||
def test_completion_deepseek():
|
||||
litellm.set_verbose = True
|
||||
model_name = "deepseek/deepseek-chat"
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# Palm tests
|
||||
def test_completion_palm():
|
||||
litellm.set_verbose = True
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
import litellm
|
||||
from litellm import get_optional_params
|
||||
|
||||
litellm.add_function_to_prompt = True
|
||||
optional_params = get_optional_params(
|
||||
tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
|
||||
tool_choice= 'auto',
|
||||
)
|
||||
assert optional_params is not None
|
|
@ -20,7 +20,8 @@ from litellm.caching import DualCache
|
|||
### UNIT TESTS FOR cost ROUTING ###
|
||||
|
||||
|
||||
def test_get_available_deployments():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_deployments():
|
||||
test_cache = DualCache()
|
||||
model_list = [
|
||||
{
|
||||
|
@ -40,7 +41,7 @@ def test_get_available_deployments():
|
|||
model_group = "gpt-3.5-turbo"
|
||||
|
||||
## CHECK WHAT'S SELECTED ##
|
||||
selected_model = lowest_cost_logger.get_available_deployments(
|
||||
selected_model = await lowest_cost_logger.async_get_available_deployments(
|
||||
model_group=model_group, healthy_deployments=model_list
|
||||
)
|
||||
print("selected model: ", selected_model)
|
||||
|
@ -48,7 +49,8 @@ def test_get_available_deployments():
|
|||
assert selected_model["model_info"]["id"] == "groq-llama"
|
||||
|
||||
|
||||
def test_get_available_deployments_custom_price():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_deployments_custom_price():
|
||||
from litellm._logging import verbose_router_logger
|
||||
import logging
|
||||
|
||||
|
@ -89,7 +91,7 @@ def test_get_available_deployments_custom_price():
|
|||
model_group = "gpt-3.5-turbo"
|
||||
|
||||
## CHECK WHAT'S SELECTED ##
|
||||
selected_model = lowest_cost_logger.get_available_deployments(
|
||||
selected_model = await lowest_cost_logger.async_get_available_deployments(
|
||||
model_group=model_group, healthy_deployments=model_list
|
||||
)
|
||||
print("selected model: ", selected_model)
|
||||
|
@ -142,7 +144,7 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
|
|||
response_obj = {"usage": {"total_tokens": tokens_used}}
|
||||
time.sleep(duration)
|
||||
end_time = time.time()
|
||||
lowest_cost_logger.log_success_event(
|
||||
await lowest_cost_logger.async_log_success_event(
|
||||
response_obj=response_obj,
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
|
@ -150,14 +152,11 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
|
|||
)
|
||||
|
||||
|
||||
async def _gather_deploy(all_deploys):
|
||||
return await asyncio.gather(*[_deploy(*t) for t in all_deploys])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ans_rpm", [1, 5]
|
||||
) # 1 should produce nothing, 10 should select first
|
||||
def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
|
||||
"""
|
||||
Pass in list of 2 valid models
|
||||
|
||||
|
@ -193,9 +192,13 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
|
|||
model_group = "gpt-3.5-turbo"
|
||||
d1 = [(lowest_cost_logger, "1234", 50, 0.01)] * non_ans_rpm
|
||||
d2 = [(lowest_cost_logger, "5678", 50, 0.01)] * non_ans_rpm
|
||||
asyncio.run(_gather_deploy([*d1, *d2]))
|
||||
|
||||
await asyncio.gather(*[_deploy(*t) for t in [*d1, *d2]])
|
||||
|
||||
asyncio.sleep(3)
|
||||
|
||||
## CHECK WHAT'S SELECTED ##
|
||||
d_ans = lowest_cost_logger.get_available_deployments(
|
||||
d_ans = await lowest_cost_logger.async_get_available_deployments(
|
||||
model_group=model_group, healthy_deployments=model_list
|
||||
)
|
||||
assert (d_ans and d_ans["model_info"]["id"]) == ans
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
@ -10,10 +11,10 @@ sys.path.insert(
|
|||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
from unittest import mock
|
||||
|
||||
## for ollama we can't test making the completion call
|
||||
from litellm.utils import get_optional_params, get_llm_provider
|
||||
from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider
|
||||
|
||||
|
||||
def test_get_ollama_params():
|
||||
|
@ -58,3 +59,50 @@ def test_ollama_json_mode():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_ollama_json_mode()
|
||||
|
||||
|
||||
mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
|
||||
|
||||
@mock.patch(
|
||||
"litellm.llms.ollama.ollama_embeddings",
|
||||
return_value=mock_ollama_embedding_response,
|
||||
)
|
||||
def test_ollama_embeddings(mock_embeddings):
|
||||
# assert that ollama_embeddings is called with the right parameters
|
||||
try:
|
||||
embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
|
||||
print(embeddings)
|
||||
mock_embeddings.assert_called_once_with(
|
||||
api_base="http://localhost:11434",
|
||||
model="nomic-embed-text",
|
||||
prompts=["hello world"],
|
||||
optional_params=mock.ANY,
|
||||
logging_obj=mock.ANY,
|
||||
model_response=mock.ANY,
|
||||
encoding=mock.ANY,
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_ollama_embeddings()
|
||||
|
||||
@mock.patch(
|
||||
"litellm.llms.ollama.ollama_aembeddings",
|
||||
return_value=mock_ollama_embedding_response,
|
||||
)
|
||||
def test_ollama_aembeddings(mock_aembeddings):
|
||||
# assert that ollama_aembeddings is called with the right parameters
|
||||
try:
|
||||
embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
|
||||
print(embeddings)
|
||||
mock_aembeddings.assert_called_once_with(
|
||||
api_base="http://localhost:11434",
|
||||
model="nomic-embed-text",
|
||||
prompts=["hello world"],
|
||||
optional_params=mock.ANY,
|
||||
logging_obj=mock.ANY,
|
||||
model_response=mock.ANY,
|
||||
encoding=mock.ANY,
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_ollama_aembeddings()
|
||||
|
|
|
@ -24,6 +24,14 @@
|
|||
|
||||
# asyncio.run(test_ollama_aembeddings())
|
||||
|
||||
# def test_ollama_embeddings():
|
||||
# litellm.set_verbose = True
|
||||
# input = "The food was delicious and the waiter..."
|
||||
# response = litellm.embedding(model="ollama/mistral", input=input)
|
||||
# print(response)
|
||||
|
||||
# test_ollama_embeddings()
|
||||
|
||||
# def test_ollama_streaming():
|
||||
# try:
|
||||
# litellm.set_verbose = False
|
||||
|
|
|
@ -340,3 +340,19 @@ class RetryPolicy(BaseModel):
|
|||
RateLimitErrorRetries: Optional[int] = None
|
||||
ContentPolicyViolationErrorRetries: Optional[int] = None
|
||||
InternalServerErrorRetries: Optional[int] = None
|
||||
|
||||
|
||||
class AlertingConfig(BaseModel):
|
||||
"""
|
||||
Use this configure alerting for the router. Receive alerts on the following events
|
||||
- LLM API Exceptions
|
||||
- LLM Responses Too Slow
|
||||
- LLM Requests Hanging
|
||||
|
||||
Args:
|
||||
webhook_url: str - webhook url for alerting, slack provides a webhook url to send alerts to
|
||||
alerting_threshold: Optional[float] = None - threshold for slow / hanging llm responses (in seconds)
|
||||
"""
|
||||
|
||||
webhook_url: str
|
||||
alerting_threshold: Optional[float] = 300
|
||||
|
|
|
@ -4836,7 +4836,7 @@ def get_optional_params(
|
|||
**kwargs,
|
||||
):
|
||||
# retrieve all parameters passed to the function
|
||||
passed_params = locals()
|
||||
passed_params = locals().copy()
|
||||
special_params = passed_params.pop("kwargs")
|
||||
for k, v in special_params.items():
|
||||
if k.startswith("aws_") and (
|
||||
|
@ -4929,9 +4929,11 @@ def get_optional_params(
|
|||
and custom_llm_provider != "anyscale"
|
||||
and custom_llm_provider != "together_ai"
|
||||
and custom_llm_provider != "groq"
|
||||
and custom_llm_provider != "deepseek"
|
||||
and custom_llm_provider != "mistral"
|
||||
and custom_llm_provider != "anthropic"
|
||||
and custom_llm_provider != "cohere_chat"
|
||||
and custom_llm_provider != "cohere"
|
||||
and custom_llm_provider != "bedrock"
|
||||
and custom_llm_provider != "ollama_chat"
|
||||
):
|
||||
|
@ -4956,7 +4958,7 @@ def get_optional_params(
|
|||
litellm.add_function_to_prompt
|
||||
): # if user opts to add it to prompt instead
|
||||
optional_params["functions_unsupported_model"] = non_default_params.pop(
|
||||
"tools", non_default_params.pop("functions")
|
||||
"tools", non_default_params.pop("functions", None)
|
||||
)
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
|
@ -5614,6 +5616,29 @@ def get_optional_params(
|
|||
if seed is not None:
|
||||
optional_params["seed"] = seed
|
||||
|
||||
elif custom_llm_provider == "deepseek":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
|
||||
if frequency_penalty is not None:
|
||||
optional_params["frequency_penalty"] = frequency_penalty
|
||||
if max_tokens is not None:
|
||||
optional_params["max_tokens"] = max_tokens
|
||||
if presence_penalty is not None:
|
||||
optional_params["presence_penalty"] = presence_penalty
|
||||
if stop is not None:
|
||||
optional_params["stop"] = stop
|
||||
if stream is not None:
|
||||
optional_params["stream"] = stream
|
||||
if temperature is not None:
|
||||
optional_params["temperature"] = temperature
|
||||
if logprobs is not None:
|
||||
optional_params["logprobs"] = logprobs
|
||||
if top_logprobs is not None:
|
||||
optional_params["top_logprobs"] = top_logprobs
|
||||
|
||||
elif custom_llm_provider == "openrouter":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
|
@ -5946,6 +5971,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
|
|||
"response_format",
|
||||
"seed",
|
||||
]
|
||||
elif custom_llm_provider == "deepseek":
|
||||
return [
|
||||
# https://platform.deepseek.com/api-docs/api/create-chat-completion
|
||||
"frequency_penalty",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
]
|
||||
elif custom_llm_provider == "cohere":
|
||||
return [
|
||||
"stream",
|
||||
|
@ -6239,8 +6277,12 @@ def get_llm_provider(
|
|||
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
|
||||
api_base = "https://api.groq.com/openai/v1"
|
||||
dynamic_api_key = get_secret("GROQ_API_KEY")
|
||||
elif custom_llm_provider == "deepseek":
|
||||
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
|
||||
api_base = "https://api.deepseek.com/v1"
|
||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
|
||||
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
|
||||
if not model.startswith("accounts/fireworks/models"):
|
||||
model = f"accounts/fireworks/models/{model}"
|
||||
api_base = "https://api.fireworks.ai/inference/v1"
|
||||
|
@ -6303,6 +6345,9 @@ def get_llm_provider(
|
|||
elif endpoint == "api.groq.com/openai/v1":
|
||||
custom_llm_provider = "groq"
|
||||
dynamic_api_key = get_secret("GROQ_API_KEY")
|
||||
elif endpoint == "api.deepseek.com/v1":
|
||||
custom_llm_provider = "deepseek"
|
||||
dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
|
||||
|
@ -6901,6 +6946,11 @@ def validate_environment(model: Optional[str] = None) -> dict:
|
|||
keys_in_environment = True
|
||||
else:
|
||||
missing_keys.append("GROQ_API_KEY")
|
||||
elif custom_llm_provider == "deepseek":
|
||||
if "DEEPSEEK_API_KEY" in os.environ:
|
||||
keys_in_environment = True
|
||||
else:
|
||||
missing_keys.append("DEEPSEEK_API_KEY")
|
||||
elif custom_llm_provider == "mistral":
|
||||
if "MISTRAL_API_KEY" in os.environ:
|
||||
keys_in_environment = True
|
||||
|
|
|
@ -739,6 +739,24 @@
|
|||
"litellm_provider": "mistral",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"deepseek-chat": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"deepseek-coder": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 16000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"groq/llama2-70b-4096": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "1.36.1"
|
||||
version = "1.36.2"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT"
|
||||
|
@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
|
|||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.commitizen]
|
||||
version = "1.36.1"
|
||||
version = "1.36.2"
|
||||
version_files = [
|
||||
"pyproject.toml:^version"
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue