forked from phoenix/litellm-mirror

Merge branch 'main' into litellm_ui_fixes_6

commit 0e709fdc21

26 changed files with 631 additions and 185 deletions
@@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
 | [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
 | [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
+| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
 | [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
 | [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
54	docs/my-website/docs/providers/deepseek.md	(Normal file, new)
@@ -0,0 +1,54 @@
+# Deepseek
+https://deepseek.com/
+
+**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**
+
+## API Key
+```python
+# env variable
+os.environ['DEEPSEEK_API_KEY']
+```
+
+## Sample Usage
+```python
+from litellm import completion
+import os
+
+os.environ['DEEPSEEK_API_KEY'] = ""
+response = completion(
+    model="deepseek/deepseek-chat",
+    messages=[
+        {"role": "user", "content": "hello from litellm"}
+    ],
+)
+print(response)
+```
+
+## Sample Usage - Streaming
+```python
+from litellm import completion
+import os
+
+os.environ['DEEPSEEK_API_KEY'] = ""
+response = completion(
+    model="deepseek/deepseek-chat",
+    messages=[
+        {"role": "user", "content": "hello from litellm"}
+    ],
+    stream=True
+)
+
+for chunk in response:
+    print(chunk)
+```
+
+
+## Supported Models - ALL Deepseek Models Supported!
+We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests
+
+| Model Name | Function Call |
+|----------------|---------------------------------------------------------|
+| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
+| deepseek-coder | `completion(model="deepseek/deepseek-chat", messages)` |
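Editor's note: the new doc's samples only call `deepseek-chat`. A minimal sketch (not part of the commit) of the same call against the `deepseek-coder` model listed in the table, assuming a valid `DEEPSEEK_API_KEY`:

```python
# Hedged sketch: same completion() call as the docs above, against deepseek-coder.
import os
from litellm import completion

os.environ["DEEPSEEK_API_KEY"] = ""  # a real key is assumed in practice
response = completion(
    model="deepseek/deepseek-coder",  # the "deepseek/" prefix routes to the new provider
    messages=[{"role": "user", "content": "write a quicksort in python"}],
)
print(response.choices[0].message.content)
```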
@@ -468,7 +468,7 @@ asyncio.run(router_acompletion())
 ```
 
 </TabItem>
-<TabItem value="lowest-cost" label="Lowest Cost Routing">
+<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
 
 Picks a deployment based on the lowest cost
 
@@ -1086,6 +1086,46 @@ async def test_acompletion_caching_on_router_caching_groups():
 asyncio.run(test_acompletion_caching_on_router_caching_groups())
 ```
 
+## Alerting 🚨
+
+Send alerts to slack / your webhook url for the following events
+- LLM API Exceptions
+- Slow LLM Responses
+
+Get a slack webhook url from https://api.slack.com/messaging/webhooks
+
+#### Usage
+Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key=bad-key` which is invalid
+
+```python
+from litellm.router import AlertingConfig
+import litellm
+import os
+
+router = litellm.Router(
+    model_list=[
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": "bad_key",
+            },
+        }
+    ],
+    alerting_config=AlertingConfig(
+        alerting_threshold=10,  # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
+        webhook_url=os.getenv("SLACK_WEBHOOK_URL"),  # webhook you want to send alerts to
+    ),
+)
+try:
+    await router.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+except:
+    pass
+```
+
 ## Track cost for Azure Deployments
 
 **Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
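Editor's note: the docs snippet above uses `await` at top level. A hedged sketch (not part of the commit) of running the same example as a plain script by wrapping it in an event loop:

```python
# Sketch only: the docs' alerting example wrapped in asyncio.run so it runs as a script.
import asyncio
import os
import litellm
from litellm.router import AlertingConfig

async def main():
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "bad_key"},
            }
        ],
        alerting_config=AlertingConfig(
            alerting_threshold=10,  # seconds before a request counts as slow/hanging
            webhook_url=os.getenv("SLACK_WEBHOOK_URL"),
        ),
    )
    try:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
    except Exception:
        pass  # the bad key is expected to fail and trigger the alert

asyncio.run(main())
```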
@@ -134,6 +134,7 @@ const sidebars = {
             "providers/ollama",
             "providers/perplexity",
             "providers/groq",
+            "providers/deepseek",
             "providers/fireworks_ai",
             "providers/vllm",
             "providers/xinference",
@@ -361,6 +361,7 @@ openai_compatible_endpoints: List = [
     "api.deepinfra.com/v1/openai",
     "api.mistral.ai/v1",
     "api.groq.com/openai/v1",
+    "api.deepseek.com/v1",
     "api.together.xyz/v1",
 ]
 
@@ -369,6 +370,7 @@ openai_compatible_providers: List = [
     "anyscale",
     "mistral",
     "groq",
+    "deepseek",
     "deepinfra",
     "perplexity",
     "xinference",
@@ -523,6 +525,7 @@ provider_list: List = [
     "anyscale",
     "mistral",
     "groq",
+    "deepseek",
     "maritalk",
     "voyage",
     "cloudflare",
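Editor's note: a quick sanity check after these registration hunks — a sketch, assuming the lists stay exposed as module-level attributes exactly as shown above:

```python
# Sketch: confirm deepseek is registered as an OpenAI-compatible provider.
import litellm

assert "deepseek" in litellm.provider_list
assert "deepseek" in litellm.openai_compatible_providers
assert "api.deepseek.com/v1" in litellm.openai_compatible_endpoints
```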
@@ -276,7 +276,6 @@ class LangFuseLogger:
             metadata_tags = metadata.pop("tags", [])
             tags = metadata_tags
 
-
         # Clean Metadata before logging - never log raw metadata
         # the raw metadata can contain circular references which leads to infinite recursion
         # we clean out all extra litellm metadata params before logging
@@ -303,7 +302,6 @@ class LangFuseLogger:
             else:
                 clean_metadata[key] = value
 
-
         session_id = clean_metadata.pop("session_id", None)
         trace_name = clean_metadata.pop("trace_name", None)
         trace_id = clean_metadata.pop("trace_id", None)
@@ -322,13 +320,16 @@ class LangFuseLogger:
             for metadata_param_key in update_trace_keys:
                 trace_param_key = metadata_param_key.replace("trace_", "")
                 if trace_param_key not in trace_params:
-                    updated_trace_value = clean_metadata.pop(metadata_param_key, None)
+                    updated_trace_value = clean_metadata.pop(
+                        metadata_param_key, None
+                    )
                     if updated_trace_value is not None:
                         trace_params[trace_param_key] = updated_trace_value
 
             # Pop the trace specific keys that would have been popped if there were a new trace
-            for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
+            for key in list(
+                filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
+            ):
                 clean_metadata.pop(key, None)
 
             # Special keys that are found in the function arguments and not the metadata
@@ -342,10 +343,16 @@ class LangFuseLogger:
                 "name": trace_name,
                 "session_id": session_id,
                 "input": input,
-                "version": clean_metadata.pop("trace_version", clean_metadata.get("version", None)),  # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
+                "version": clean_metadata.pop(
+                    "trace_version", clean_metadata.get("version", None)
+                ),  # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
             }
-            for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
-                trace_params[key.replace("trace_", "")] = clean_metadata.pop(key, None)
+            for key in list(
+                filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
+            ):
+                trace_params[key.replace("trace_", "")] = clean_metadata.pop(
+                    key, None
+                )
 
             if level == "ERROR":
                 trace_params["status_message"] = output
@@ -68,11 +68,15 @@ class SlackAlertingCacheKeys(Enum):
 
 
 class SlackAlerting(CustomLogger):
+    """
+    Class for sending Slack Alerts
+    """
+
     # Class variables or attributes
     def __init__(
         self,
         internal_usage_cache: Optional[DualCache] = None,
-        alerting_threshold: float = 300,
+        alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
         alert_types: Optional[
             List[
@@ -97,6 +101,7 @@ class SlackAlerting(CustomLogger):
             Dict
         ] = None,  # if user wants to separate alerts to diff channels
         alerting_args={},
+        default_webhook_url: Optional[str] = None,
     ):
         self.alerting_threshold = alerting_threshold
         self.alerting = alerting
@@ -106,6 +111,7 @@ class SlackAlerting(CustomLogger):
         self.alert_to_webhook_url = alert_to_webhook_url
         self.is_running = False
         self.alerting_args = SlackAlertingArgs(**alerting_args)
+        self.default_webhook_url = default_webhook_url
 
     def update_values(
         self,
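Editor's note: with the new `default_webhook_url` parameter, a `SlackAlerting` instance can be pointed at a webhook without relying on the `SLACK_WEBHOOK_URL` fallback. A hedged sketch, mirroring how `Router._initialize_alerting` (later in this diff) constructs it:

```python
# Sketch: construct the alerting logger the same way the router does further down in this diff.
import os
from litellm.integrations.slack_alerting import SlackAlerting

slack_alerting = SlackAlerting(
    alerting_threshold=300,  # seconds before a request counts as slow/hanging
    alerting=["slack"],
    default_webhook_url=os.getenv("SLACK_WEBHOOK_URL"),
)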
@@ -149,16 +155,21 @@ class SlackAlerting(CustomLogger):
 
     def _add_langfuse_trace_id_to_alert(
         self,
-        request_info: str,
         request_data: Optional[dict] = None,
-        kwargs: Optional[dict] = None,
-        type: Literal["hanging_request", "slow_response"] = "hanging_request",
-        start_time: Optional[datetime.datetime] = None,
-        end_time: Optional[datetime.datetime] = None,
-    ):
+    ) -> Optional[str]:
+        """
+        Returns langfuse trace url
+        """
         # do nothing for now
-        pass
-        return request_info
+        if (
+            request_data is not None
+            and request_data.get("metadata", {}).get("trace_id", None) is not None
+        ):
+            trace_id = request_data["metadata"]["trace_id"]
+            if litellm.utils.langFuseLogger is not None:
+                base_url = litellm.utils.langFuseLogger.Langfuse.base_url
+                return f"{base_url}/trace/{trace_id}"
+        return None
 
     def _response_taking_too_long_callback_helper(
         self,
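Editor's note: the rewritten helper only returns a URL when the request metadata carries a `trace_id` and a Langfuse logger is configured; the link it builds is `{base_url}/trace/{trace_id}`. A hedged sketch of the caller side, assuming the `trace_id` is forwarded via request metadata the way litellm's Langfuse integration reads it:

```python
# Sketch: a request whose metadata carries a Langfuse trace_id, so alerts about it
# can link back to f"{base_url}/trace/{trace_id}" as the helper above does.
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={"trace_id": "my-trace-id"},  # hypothetical id, picked up by the helper above
)
```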
@@ -302,7 +313,7 @@ class SlackAlerting(CustomLogger):
         except Exception as e:
             return 0
 
-    async def send_daily_reports(self, router: litellm.Router) -> bool:
+    async def send_daily_reports(self, router) -> bool:
         """
         Send a daily report on:
         - Top 5 deployments with most failed requests
@@ -501,14 +512,13 @@ class SlackAlerting(CustomLogger):
             )
 
             if "langfuse" in litellm.success_callback:
-                request_info = self._add_langfuse_trace_id_to_alert(
-                    request_info=request_info,
+                langfuse_url = self._add_langfuse_trace_id_to_alert(
                     request_data=request_data,
-                    type="hanging_request",
-                    start_time=start_time,
-                    end_time=end_time,
                 )
 
+                if langfuse_url is not None:
+                    request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
+
             # add deployment latencies to alert
             _deployment_latency_map = self._get_deployment_latencies_to_alert(
                 metadata=request_data.get("metadata", {})
@@ -701,6 +711,7 @@ Model Info:
             "daily_reports",
             "new_model_added",
         ],
+        **kwargs,
     ):
         """
         Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
@@ -731,6 +742,10 @@ Model Info:
         formatted_message = (
             f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
         )
 
+        if kwargs:
+            for key, value in kwargs.items():
+                formatted_message += f"\n\n{key}: `{value}`\n\n"
+
         if _proxy_base_url is not None:
             formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
@@ -740,6 +755,8 @@ Model Info:
             and alert_type in self.alert_to_webhook_url
         ):
             slack_webhook_url = self.alert_to_webhook_url[alert_type]
+        elif self.default_webhook_url is not None:
+            slack_webhook_url = self.default_webhook_url
         else:
             slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
@@ -796,8 +813,16 @@ Model Info:
                     updated_at=litellm.utils.get_utc_datetime(),
                 )
             )
+        if "llm_exceptions" in self.alert_types:
+            original_exception = kwargs.get("exception", None)
 
-    async def _run_scheduler_helper(self, llm_router: litellm.Router) -> bool:
+            await self.send_alert(
+                message="LLM API Failure - " + str(original_exception),
+                level="High",
+                alert_type="llm_exceptions",
+            )
+
+    async def _run_scheduler_helper(self, llm_router) -> bool:
         """
         Returns:
         - True -> report sent
@@ -839,7 +864,7 @@ Model Info:
 
         return report_sent_bool
 
-    async def _run_scheduled_daily_report(self, llm_router: Optional[litellm.Router]):
+    async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
         """
         If 'daily_reports' enabled
 
@@ -474,3 +474,23 @@ async def ollama_aembeddings(
         "total_tokens": total_input_tokens,
     }
     return model_response
+
+
+def ollama_embeddings(
+    api_base: str,
+    model: str,
+    prompts: list,
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    return asyncio.run(
+        ollama_aembeddings(
+            api_base,
+            model,
+            prompts,
+            optional_params,
+            logging_obj,
+            model_response,
+            encoding)
+    )
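Editor's note: the new synchronous wrapper lets the non-async embedding path reuse `ollama_aembeddings` via `asyncio.run`. A hedged usage sketch through the public API — the same call the new ollama test at the bottom of this diff makes, assuming a local Ollama server on the default port:

```python
# Sketch: synchronous embedding call that ends up in ollama_embeddings()
# when the async path is not requested.
import litellm

embeddings = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
print(embeddings)
```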
@@ -305,6 +305,7 @@ async def acompletion(
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "deepseek"
             or custom_llm_provider == "text-completion-openai"
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
@@ -982,6 +983,7 @@ def completion(
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "openai"
@@ -2565,6 +2567,7 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
+        or custom_llm_provider == "deepseek"
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
@@ -2947,8 +2950,8 @@ def embedding(
                 model=model,  # type: ignore
                 llm_provider="ollama",  # type: ignore
             )
-        if aembedding:
-            response = ollama.ollama_aembeddings(
+        ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
+        response = ollama_embeddings_fn(
             api_base=api_base,
             model=model,
             prompts=input,
@@ -3085,6 +3088,7 @@ async def atext_completion(*args, **kwargs):
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
+        or custom_llm_provider == "deepseek"
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "text-completion-openai"
         or custom_llm_provider == "huggingface"
@@ -739,6 +739,24 @@
         "litellm_provider": "mistral",
         "mode": "embedding"
     },
+    "deepseek-chat": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
+    "deepseek-coder": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
     "groq/llama2-70b-4096": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
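Editor's note: with the per-token prices added above, the cost of a Deepseek call is straightforward arithmetic. A small sketch — the rates are copied from the JSON entries, the token counts are made up:

```python
# Sketch: cost of a hypothetical 1,000 prompt-token / 500 completion-token deepseek-chat call.
input_cost_per_token = 0.00000014
output_cost_per_token = 0.00000028

prompt_tokens, completion_tokens = 1000, 500
cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${cost:.6f}")  # $0.000280
```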
@@ -14,6 +14,9 @@ model_list:
     api_key: my-fake-key-3
     model: openai/my-fake-model-3
   model_name: fake-openai-endpoint
+- model_name: gpt-4
+  litellm_params:
+    model: gpt-3.5-turbo
 router_settings:
   num_retries: 0
   enable_pre_call_checks: true
@@ -25,7 +28,7 @@ router_settings:
   routing_strategy: "latency-based-routing"
 
 litellm_settings:
-  success_callback: ["openmeter"]
+  success_callback: ["langfuse"]
 
 general_settings:
   alerting: ["slack"]
@@ -2531,6 +2531,7 @@ class ProxyConfig:
         if "db_model" in model.model_info and model.model_info["db_model"] == False:
             model.model_info["db_model"] = db_model
             _model_info = RouterModelInfo(**model.model_info)
+
         else:
             _model_info = RouterModelInfo(id=model.model_id, db_model=db_model)
         return _model_info
@@ -3175,7 +3176,9 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"
 
 
-async def async_data_generator(response, user_api_key_dict):
+async def async_data_generator(
+    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
     verbose_proxy_logger.debug("inside generator")
     try:
         start_time = time.time()
@@ -3192,7 +3195,9 @@ async def async_data_generator(response, user_api_key_dict):
     except Exception as e:
         traceback.print_exc()
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=request_data,
         )
         verbose_proxy_logger.debug(
             f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
@@ -3217,8 +3222,14 @@ async def async_data_generator(response, user_api_key_dict):
         yield f"data: {error_returned}\n\n"
 
 
-def select_data_generator(response, user_api_key_dict):
-    return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)
+def select_data_generator(
+    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+    return async_data_generator(
+        response=response,
+        user_api_key_dict=user_api_key_dict,
+        request_data=request_data,
+    )
 
 
 def get_litellm_model_info(model: dict = {}):
@@ -3513,9 +3524,8 @@ async def chat_completion(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
-    try:
-        # async with llm_router.sem
-        data = {}
+    data = {}
+    try:
         body = await request.body()
         body_str = body.decode()
         try:
@@ -3706,7 +3716,9 @@ async def chat_completion(
                 "x-litellm-model-api-base": api_base,
             }
             selected_data_generator = select_data_generator(
-                response=response, user_api_key_dict=user_api_key_dict
+                response=response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
             )
             return StreamingResponse(
                 selected_data_generator,
@@ -3728,7 +3740,7 @@ async def chat_completion(
         data["litellm_status"] = "fail"  # used for alerting
         traceback.print_exc()
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
         verbose_proxy_logger.debug(
             f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
@@ -3890,7 +3902,9 @@ async def completion(
                 "x-litellm-model-id": model_id,
             }
             selected_data_generator = select_data_generator(
-                response=response, user_api_key_dict=user_api_key_dict
+                response=response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
             )
 
             return StreamingResponse(
@@ -3943,6 +3957,7 @@ async def embeddings(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     global proxy_logging_obj
+    data: Any = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -4088,7 +4103,7 @@ async def embeddings(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
         traceback.print_exc()
         if isinstance(e, HTTPException):
@@ -4125,6 +4140,7 @@ async def image_generation(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     global proxy_logging_obj
+    data = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -4244,7 +4260,7 @@ async def image_generation(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
         traceback.print_exc()
         if isinstance(e, HTTPException):
@@ -4285,10 +4301,11 @@ async def audio_transcriptions(
     https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=curl
     """
     global proxy_logging_obj
+    data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         form_data = await request.form()
-        data: Dict = {key: value for key, value in form_data.items() if key != "file"}
+        data = {key: value for key, value in form_data.items() if key != "file"}
 
         # Include original request and headers in the data
         data["proxy_server_request"] = {  # type: ignore
@@ -4423,7 +4440,7 @@ async def audio_transcriptions(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
         traceback.print_exc()
         if isinstance(e, HTTPException):
@@ -4472,6 +4489,7 @@ async def moderations(
     ```
     """
     global proxy_logging_obj
+    data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -4585,7 +4603,7 @@ async def moderations(
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
         traceback.print_exc()
         if isinstance(e, HTTPException):
@@ -8048,8 +8066,8 @@ async def async_queue_request(
 
     Now using a FastAPI background task + /chat/completions compatible endpoint
     """
-    try:
-        data = {}
+    data = {}
+    try:
         data = await request.json()  # type: ignore
 
         # Include original request and headers in the data
@@ -8114,7 +8132,9 @@ async def async_queue_request(
         ):  # use generate_responses to stream responses
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
+                    request_data=data,
                 ),
                 media_type="text/event-stream",
             )
@@ -8122,7 +8142,7 @@ async def async_queue_request(
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
         if isinstance(e, HTTPException):
             raise ProxyException(
@@ -302,6 +302,7 @@ class ProxyLogging:
             "budget_alerts",
             "db_exceptions",
         ],
+        request_data: Optional[dict] = None,
     ):
         """
         Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
@@ -331,10 +332,19 @@ class ProxyLogging:
         if _proxy_base_url is not None:
             formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
 
+        extra_kwargs = {}
+        if request_data is not None:
+            _url = self.slack_alerting_instance._add_langfuse_trace_id_to_alert(
+                request_data=request_data
+            )
+            if _url is not None:
+                extra_kwargs["🪢 Langfuse Trace"] = _url
+                formatted_message += "\n\n🪢 Langfuse Trace: {}".format(_url)
+
         for client in self.alerting:
             if client == "slack":
                 await self.slack_alerting_instance.send_alert(
-                    message=message, level=level, alert_type=alert_type
+                    message=message, level=level, alert_type=alert_type, **extra_kwargs
                 )
             elif client == "sentry":
                 if litellm.utils.sentry_sdk_instance is not None:
@@ -369,6 +379,7 @@ class ProxyLogging:
                     message=f"DB read/write call failed: {error_message}",
                     level="High",
                     alert_type="db_exceptions",
+                    request_data={},
                 )
             )
 
@@ -384,7 +395,10 @@ class ProxyLogging:
             litellm.utils.capture_exception(error=original_exception)
 
     async def post_call_failure_hook(
-        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+        self,
+        original_exception: Exception,
+        user_api_key_dict: UserAPIKeyAuth,
+        request_data: dict,
     ):
         """
         Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
@@ -409,6 +423,7 @@ class ProxyLogging:
                     message=f"LLM API call failed: {str(original_exception)}",
                     level="High",
                     alert_type="llm_exceptions",
+                    request_data=request_data,
                 )
             )
 
@@ -44,6 +44,7 @@ from litellm.types.router import (
     updateDeployment,
     updateLiteLLMParams,
     RetryPolicy,
+    AlertingConfig,
 )
 from litellm.integrations.custom_logger import CustomLogger
 
@@ -103,6 +104,7 @@ class Router:
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
         semaphore: Optional[asyncio.Semaphore] = None,
+        alerting_config: Optional[AlertingConfig] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -131,7 +133,7 @@ class Router:
             cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1.
             routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
             routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
+            alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
         Returns:
             Router: An instance of the litellm.Router class.
 
@@ -316,6 +318,9 @@ class Router:
         self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
             model_group_retry_policy
         )
+        self.alerting_config: Optional[AlertingConfig] = alerting_config
+        if self.alerting_config is not None:
+            self._initialize_alerting()
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -3000,6 +3005,7 @@ class Router:
         if (
             self.routing_strategy != "usage-based-routing-v2"
             and self.routing_strategy != "simple-shuffle"
+            and self.routing_strategy != "cost-based-routing"
         ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
             return self.get_available_deployment(
                 model=model,
@@ -3056,6 +3062,16 @@ class Router:
                 messages=messages,
                 input=input,
             )
+        if (
+            self.routing_strategy == "cost-based-routing"
+            and self.lowestcost_logger is not None
+        ):
+            deployment = await self.lowestcost_logger.async_get_available_deployments(
+                model_group=model,
+                healthy_deployments=healthy_deployments,
+                messages=messages,
+                input=input,
+            )
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
             ############## Check if we can do a RPM/TPM based weighted pick #################
@@ -3226,15 +3242,6 @@ class Router:
                 messages=messages,
                 input=input,
             )
-        elif (
-            self.routing_strategy == "cost-based-routing"
-            and self.lowestcost_logger is not None
-        ):
-            deployment = self.lowestcost_logger.get_available_deployments(
-                model_group=model,
-                healthy_deployments=healthy_deployments,
-                request_kwargs=request_kwargs,
-            )
         if deployment is None:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
@@ -3360,6 +3367,23 @@ class Router:
         ):
             return retry_policy.ContentPolicyViolationErrorRetries
 
+    def _initialize_alerting(self):
+        from litellm.integrations.slack_alerting import SlackAlerting
+
+        router_alerting_config: AlertingConfig = self.alerting_config
+
+        _slack_alerting_logger = SlackAlerting(
+            alerting_threshold=router_alerting_config.alerting_threshold,
+            alerting=["slack"],
+            default_webhook_url=router_alerting_config.webhook_url,
+        )
+
+        litellm.callbacks.append(_slack_alerting_logger)
+        litellm.success_callback.append(
+            _slack_alerting_logger.response_taking_too_long_callback
+        )
+        print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n")  # noqa
+
     def flush_cache(self):
         litellm.cache = None
         self.cache.flush_cache()
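Editor's note: a hedged sketch of the new constructor argument, matching the docs hunk earlier in this diff — constructing a `Router` with `alerting_config` should register a `SlackAlerting` callback globally via `_initialize_alerting()` above:

```python
# Sketch: enabling router-level Slack alerting registers a SlackAlerting callback.
import os
import litellm
from litellm.router import AlertingConfig
from litellm.integrations.slack_alerting import SlackAlerting

router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
    ],
    alerting_config=AlertingConfig(
        alerting_threshold=300, webhook_url=os.getenv("SLACK_WEBHOOK_URL")
    ),
)
assert any(isinstance(cb, SlackAlerting) for cb in litellm.callbacks)
```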
@@ -40,7 +40,7 @@ class LowestCostLoggingHandler(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list
 
-    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+    async def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
             Update usage on success
@@ -90,7 +90,11 @@ class LowestCostLoggingHandler(CustomLogger):
             # Update usage
             # ------------
 
-            request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=cost_key) or {}
+            )
+
+            # check local result first
 
             if id not in request_count_dict:
                 request_count_dict[id] = {}
@@ -111,7 +115,9 @@ class LowestCostLoggingHandler(CustomLogger):
                 request_count_dict[id][precise_minute].get("rpm", 0) + 1
             )
 
-            self.router_cache.set_cache(key=cost_key, value=request_count_dict)
+            await self.router_cache.async_set_cache(
+                key=cost_key, value=request_count_dict
+            )
 
             ### TESTING ###
             if self.test_flag:
@@ -172,7 +178,9 @@ class LowestCostLoggingHandler(CustomLogger):
             # Update usage
             # ------------
 
-            request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
+            request_count_dict = (
+                await self.router_cache.async_get_cache(key=cost_key) or {}
+            )
 
             if id not in request_count_dict:
                 request_count_dict[id] = {}
@@ -189,7 +197,7 @@ class LowestCostLoggingHandler(CustomLogger):
                 request_count_dict[id][precise_minute].get("rpm", 0) + 1
             )
 
-            self.router_cache.set_cache(
+            await self.router_cache.async_set_cache(
                 key=cost_key, value=request_count_dict
             )  # reset map within window
 
@@ -200,7 +208,7 @@ class LowestCostLoggingHandler(CustomLogger):
             traceback.print_exc()
             pass
 
-    def get_available_deployments(
+    async def async_get_available_deployments(
         self,
         model_group: str,
         healthy_deployments: list,
@@ -213,7 +221,7 @@ class LowestCostLoggingHandler(CustomLogger):
         """
         cost_key = f"{model_group}_map"
 
-        request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
+        request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
 
         # -----------------------
         # Find lowest used model
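Editor's note: since `get_available_deployments` becomes `async_get_available_deployments`, callers now have to await it. A hedged sketch mirroring the updated unit tests further down in this diff:

```python
# Sketch: the lowest-cost handler is now queried asynchronously.
# `lowest_cost_logger` is assumed to be a LowestCostLoggingHandler built with a
# DualCache and a model_list, exactly as in the tests later in this diff.
async def pick_cheapest(lowest_cost_logger, model_list):
    selected = await lowest_cost_logger.async_get_available_deployments(
        model_group="gpt-3.5-turbo", healthy_deployments=model_list
    )
    return selected["model_info"]["id"]
```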
File diff suppressed because one or more lines are too long
@@ -18,6 +18,10 @@ from unittest.mock import patch, MagicMock
 from litellm.utils import get_api_base
 from litellm.caching import DualCache
 from litellm.integrations.slack_alerting import SlackAlerting, DeploymentMetrics
+import unittest.mock
+from unittest.mock import AsyncMock
+import pytest
+from litellm.router import AlertingConfig, Router
 
 
 @pytest.mark.parametrize(
@@ -313,3 +317,45 @@ async def test_daily_reports_redis_cache_scheduler():
 
     # second call - expect empty
     await slack_alerting._run_scheduler_helper(llm_router=router)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="Local test. Test if slack alerts are sent.")
+async def test_send_llm_exception_to_slack():
+    from litellm.router import AlertingConfig
+
+    # on async success
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "bad_key",
+                },
+            },
+            {
+                "model_name": "gpt-5-good",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+        ],
+        alerting_config=AlertingConfig(
+            alerting_threshold=0.5, webhook_url=os.getenv("SLACK_WEBHOOK_URL")
+        ),
+    )
+    try:
+        await router.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    except:
+        pass
+
+    await router.acompletion(
+        model="gpt-5-good",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+
+    await asyncio.sleep(3)
@@ -2168,9 +2168,9 @@ def test_completion_replicate_vicuna():
 
 def test_replicate_custom_prompt_dict():
     litellm.set_verbose = True
-    model_name = "replicate/meta/llama-2-70b-chat"
+    model_name = "replicate/meta/llama-2-7b"
     litellm.register_prompt_template(
-        model="replicate/meta/llama-2-70b-chat",
+        model="replicate/meta/llama-2-7b",
         initial_prompt_value="You are a good assistant",  # [OPTIONAL]
         roles={
             "system": {
@@ -2200,6 +2200,7 @@ def test_replicate_custom_prompt_dict():
             repetition_penalty=0.1,
             num_retries=3,
         )
 
     except litellm.APIError as e:
         pass
     except litellm.APIConnectionError as e:
@@ -3017,6 +3018,21 @@ async def test_acompletion_gemini():
         pytest.fail(f"Error occurred: {e}")
 
 
+# Deepseek tests
+def test_completion_deepseek():
+    litellm.set_verbose = True
+    model_name = "deepseek/deepseek-chat"
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except litellm.APIError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # Palm tests
 def test_completion_palm():
     litellm.set_verbose = True
@@ -0,0 +1,9 @@
+import litellm
+from litellm import get_optional_params
+
+litellm.add_function_to_prompt = True
+optional_params = get_optional_params(
+    tools=[{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
+    tool_choice='auto',
+)
+assert optional_params is not None
@@ -20,7 +20,8 @@ from litellm.caching import DualCache
 ### UNIT TESTS FOR cost ROUTING ###


-def test_get_available_deployments():
+@pytest.mark.asyncio
+async def test_get_available_deployments():
     test_cache = DualCache()
     model_list = [
         {
@@ -40,7 +41,7 @@ def test_get_available_deployments():
     model_group = "gpt-3.5-turbo"

     ## CHECK WHAT'S SELECTED ##
-    selected_model = lowest_cost_logger.get_available_deployments(
+    selected_model = await lowest_cost_logger.async_get_available_deployments(
         model_group=model_group, healthy_deployments=model_list
     )
     print("selected model: ", selected_model)
@@ -48,7 +49,8 @@ def test_get_available_deployments():
     assert selected_model["model_info"]["id"] == "groq-llama"


-def test_get_available_deployments_custom_price():
+@pytest.mark.asyncio
+async def test_get_available_deployments_custom_price():
     from litellm._logging import verbose_router_logger
     import logging

@@ -89,7 +91,7 @@ def test_get_available_deployments_custom_price():
     model_group = "gpt-3.5-turbo"

     ## CHECK WHAT'S SELECTED ##
-    selected_model = lowest_cost_logger.get_available_deployments(
+    selected_model = await lowest_cost_logger.async_get_available_deployments(
         model_group=model_group, healthy_deployments=model_list
     )
     print("selected model: ", selected_model)
@@ -142,7 +144,7 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
     response_obj = {"usage": {"total_tokens": tokens_used}}
     time.sleep(duration)
     end_time = time.time()
-    lowest_cost_logger.log_success_event(
+    await lowest_cost_logger.async_log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
         start_time=start_time,
@@ -150,14 +152,11 @@ async def _deploy(lowest_cost_logger, deployment_id, tokens_used, duration):
     )


-async def _gather_deploy(all_deploys):
-    return await asyncio.gather(*[_deploy(*t) for t in all_deploys])
-
-
 @pytest.mark.parametrize(
     "ans_rpm", [1, 5]
 ) # 1 should produce nothing, 10 should select first
-def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
+@pytest.mark.asyncio
+async def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
     """
     Pass in list of 2 valid models

@@ -193,9 +192,13 @@ def test_get_available_endpoints_tpm_rpm_check_async(ans_rpm):
     model_group = "gpt-3.5-turbo"
     d1 = [(lowest_cost_logger, "1234", 50, 0.01)] * non_ans_rpm
     d2 = [(lowest_cost_logger, "5678", 50, 0.01)] * non_ans_rpm
-    asyncio.run(_gather_deploy([*d1, *d2]))
+
+    await asyncio.gather(*[_deploy(*t) for t in [*d1, *d2]])
+
+    asyncio.sleep(3)
+
     ## CHECK WHAT'S SELECTED ##
-    d_ans = lowest_cost_logger.get_available_deployments(
+    d_ans = await lowest_cost_logger.async_get_available_deployments(
         model_group=model_group, healthy_deployments=model_list
     )
     assert (d_ans and d_ans["model_info"]["id"]) == ans
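The hunks above convert the lowest-cost routing tests to the async selection API: `get_available_deployments` becomes an awaited `async_get_available_deployments`, and each test gains `@pytest.mark.asyncio`. A minimal sketch of the same conversion pattern, using a hypothetical `FakeCostLogger` stand-in rather than litellm's real handler (whose constructor is not shown in this diff); it assumes `pytest-asyncio` is installed:

```python
import asyncio

import pytest


class FakeCostLogger:
    """Hypothetical stand-in mirroring the sync/async pair being migrated."""

    def __init__(self, deployments):
        self.deployments = deployments

    async def async_get_available_deployments(self, model_group, healthy_deployments):
        # awaitable so it can sit on top of an async cache, like the real handler
        await asyncio.sleep(0)
        return healthy_deployments[0]


@pytest.mark.asyncio
async def test_selection_is_awaited():
    logger = FakeCostLogger(deployments=[{"model_info": {"id": "groq-llama"}}])
    selected = await logger.async_get_available_deployments(
        model_group="gpt-3.5-turbo", healthy_deployments=logger.deployments
    )
    assert selected["model_info"]["id"] == "groq-llama"
```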
@@ -1,3 +1,4 @@
+import asyncio
 import sys, os
 import traceback
 from dotenv import load_dotenv
@@ -10,10 +11,10 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
+from unittest import mock

 ## for ollama we can't test making the completion call
-from litellm.utils import get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider


 def test_get_ollama_params():
@@ -58,3 +59,50 @@ def test_ollama_json_mode():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_ollama_json_mode()
+
+
+mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_embeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_embeddings(mock_embeddings):
+    # assert that ollama_embeddings is called with the right parameters
+    try:
+        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        print(embeddings)
+        mock_embeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_embeddings()
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_aembeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_aembeddings(mock_aembeddings):
+    # assert that ollama_aembeddings is called with the right parameters
+    try:
+        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        print(embeddings)
+        mock_aembeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_aembeddings()
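The new tests above patch `litellm.llms.ollama.ollama_embeddings` / `ollama_aembeddings`, so no server is required. For reference, a hedged sketch of exercising the same path for real, assuming a local Ollama daemon on its default port with `nomic-embed-text` pulled:

```python
import litellm

# Real call against a local Ollama daemon (http://localhost:11434 is Ollama's default).
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
print(len(response.data), "embedding vector(s) returned")
```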
@@ -24,6 +24,14 @@
 # asyncio.run(test_ollama_aembeddings())

+# def test_ollama_embeddings():
+#     litellm.set_verbose = True
+#     input = "The food was delicious and the waiter..."
+#     response = litellm.embedding(model="ollama/mistral", input=input)
+#     print(response)
+
+# test_ollama_embeddings()
+
 # def test_ollama_streaming():
 #     try:
 #         litellm.set_verbose = False
@@ -340,3 +340,19 @@ class RetryPolicy(BaseModel):
     RateLimitErrorRetries: Optional[int] = None
     ContentPolicyViolationErrorRetries: Optional[int] = None
     InternalServerErrorRetries: Optional[int] = None
+
+
+class AlertingConfig(BaseModel):
+    """
+    Use this to configure alerting for the router. Receive alerts on the following events
+    - LLM API Exceptions
+    - LLM Responses Too Slow
+    - LLM Requests Hanging
+
+    Args:
+        webhook_url: str - webhook url for alerting, slack provides a webhook url to send alerts to
+        alerting_threshold: Optional[float] = None - threshold for slow / hanging llm responses (in seconds)
+    """
+
+    webhook_url: str
+    alerting_threshold: Optional[float] = 300
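A sketch of how the new `AlertingConfig` would be used, assuming the Router accepts an `alerting_config` argument as part of this change (that wiring is not shown in this hunk) and that the class lives in the router types module; the webhook URL is a placeholder:

```python
from litellm import Router
from litellm.types.router import AlertingConfig  # import path assumed for this sketch

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "deepseek/deepseek-chat"},
        }
    ],
    alerting_config=AlertingConfig(
        webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder webhook
        alerting_threshold=300,  # seconds before a request is treated as slow / hanging
    ),
)
```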
@@ -4836,7 +4836,7 @@ def get_optional_params(
     **kwargs,
 ):
     # retrieve all parameters passed to the function
-    passed_params = locals()
+    passed_params = locals().copy()
     special_params = passed_params.pop("kwargs")
     for k, v in special_params.items():
         if k.startswith("aws_") and (
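The switch to `locals().copy()` is a defensive copy: the next line pops `"kwargs"` out of the mapping, and mutating the object returned by `locals()` inside a function leans on interpreter-specific frame behaviour, while a copied dict is unambiguously safe to modify. A minimal sketch of the resulting pattern (names are illustrative, not litellm's):

```python
def collect_params_sketch(model=None, temperature=None, max_tokens=None, **kwargs):
    # Copy first, then mutate the copy; the frame's own locals() mapping is left alone.
    passed_params = locals().copy()
    special_params = passed_params.pop("kwargs")
    passed_params.update(special_params)
    return {k: v for k, v in passed_params.items() if v is not None}


print(collect_params_sketch(model="deepseek/deepseek-chat", temperature=0.2, api_key="sk-placeholder"))
# {'model': 'deepseek/deepseek-chat', 'temperature': 0.2, 'api_key': 'sk-placeholder'}
```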
@@ -4929,9 +4929,11 @@ def get_optional_params(
         and custom_llm_provider != "anyscale"
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "groq"
+        and custom_llm_provider != "deepseek"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
         and custom_llm_provider != "cohere_chat"
+        and custom_llm_provider != "cohere"
         and custom_llm_provider != "bedrock"
         and custom_llm_provider != "ollama_chat"
     ):
@@ -4956,7 +4958,7 @@ def get_optional_params(
             litellm.add_function_to_prompt
         ): # if user opts to add it to prompt instead
             optional_params["functions_unsupported_model"] = non_default_params.pop(
-                "tools", non_default_params.pop("functions")
+                "tools", non_default_params.pop("functions", None)
             )
         else:
             raise UnsupportedParamsError(
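The added `None` default matters because Python evaluates the fallback argument eagerly: even when `"tools"` is present, the inner `pop("functions")` runs first and used to raise `KeyError` whenever only `tools` was supplied. A small illustration:

```python
non_default_params = {"tools": [{"type": "function", "function": {"name": "get_current_weather"}}]}

# Before: the inner pop is evaluated first and raises, because "functions" is absent.
# non_default_params.pop("tools", non_default_params.pop("functions"))      # KeyError

# After: the inner pop falls back to None, and the outer pop returns the "tools" value.
value = non_default_params.pop("tools", non_default_params.pop("functions", None))
print(value)  # [{'type': 'function', 'function': {'name': 'get_current_weather'}}]
```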
@@ -5614,6 +5616,29 @@ def get_optional_params(
         if seed is not None:
             optional_params["seed"] = seed

+    elif custom_llm_provider == "deepseek":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+
+        if frequency_penalty is not None:
+            optional_params["frequency_penalty"] = frequency_penalty
+        if max_tokens is not None:
+            optional_params["max_tokens"] = max_tokens
+        if presence_penalty is not None:
+            optional_params["presence_penalty"] = presence_penalty
+        if stop is not None:
+            optional_params["stop"] = stop
+        if stream is not None:
+            optional_params["stream"] = stream
+        if temperature is not None:
+            optional_params["temperature"] = temperature
+        if logprobs is not None:
+            optional_params["logprobs"] = logprobs
+        if top_logprobs is not None:
+            optional_params["top_logprobs"] = top_logprobs
+
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -5946,6 +5971,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "response_format",
             "seed",
         ]
+    elif custom_llm_provider == "deepseek":
+        return [
+            # https://platform.deepseek.com/api-docs/api/create-chat-completion
+            "frequency_penalty",
+            "max_tokens",
+            "presence_penalty",
+            "stop",
+            "stream",
+            "temperature",
+            "top_p",
+            "logprobs",
+            "top_logprobs",
+        ]
    elif custom_llm_provider == "cohere":
        return [
            "stream",
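Together, these two hunks let a `deepseek/` completion call carry the listed OpenAI-style parameters instead of rejecting them as unsupported. A hedged sketch (replace the placeholder key with a real `DEEPSEEK_API_KEY` before running):

```python
import os

import litellm

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"  # swap in a real key

response = litellm.completion(
    model="deepseek/deepseek-chat",
    messages=[{"role": "user", "content": "Reply with one word."}],
    # every argument below is in the new deepseek supported-params list
    max_tokens=16,
    temperature=0.2,
    top_p=0.9,
    frequency_penalty=0.1,
    presence_penalty=0.1,
    stop=["\n\n"],
)
print(response.choices[0].message.content)
```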
@@ -6239,8 +6277,12 @@ def get_llm_provider(
            # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
            api_base = "https://api.groq.com/openai/v1"
            dynamic_api_key = get_secret("GROQ_API_KEY")
+        elif custom_llm_provider == "deepseek":
+            # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
+            api_base = "https://api.deepseek.com/v1"
+            dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
        elif custom_llm_provider == "fireworks_ai":
-            # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
+            # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
            if not model.startswith("accounts/fireworks/models"):
                model = f"accounts/fireworks/models/{model}"
            api_base = "https://api.fireworks.ai/inference/v1"
@@ -6303,6 +6345,9 @@ def get_llm_provider(
            elif endpoint == "api.groq.com/openai/v1":
                custom_llm_provider = "groq"
                dynamic_api_key = get_secret("GROQ_API_KEY")
+            elif endpoint == "api.deepseek.com/v1":
+                custom_llm_provider = "deepseek"
+                dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
        return model, custom_llm_provider, dynamic_api_key, api_base

    # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
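With both additions, the `deepseek/` prefix and an `api.deepseek.com/v1` api_base now resolve to the deepseek provider with its OpenAI-compatible base URL. A quick check, assuming `get_llm_provider` keeps the four-tuple return visible in the hunk:

```python
import os

from litellm.utils import get_llm_provider

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"

model, provider, api_key, api_base = get_llm_provider(model="deepseek/deepseek-chat")
print(model)     # expected: deepseek-chat
print(provider)  # expected: deepseek
print(api_base)  # expected: https://api.deepseek.com/v1
```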
@@ -6901,6 +6946,11 @@ def validate_environment(model: Optional[str] = None) -> dict:
                keys_in_environment = True
            else:
                missing_keys.append("GROQ_API_KEY")
+        elif custom_llm_provider == "deepseek":
+            if "DEEPSEEK_API_KEY" in os.environ:
+                keys_in_environment = True
+            else:
+                missing_keys.append("DEEPSEEK_API_KEY")
        elif custom_llm_provider == "mistral":
            if "MISTRAL_API_KEY" in os.environ:
                keys_in_environment = True
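A quick way to exercise the new branch, assuming `validate_environment` keeps returning a dict with `keys_in_environment` and `missing_keys`:

```python
import os

import litellm

os.environ.pop("DEEPSEEK_API_KEY", None)
print(litellm.validate_environment(model="deepseek/deepseek-chat"))
# expected: {'keys_in_environment': False, 'missing_keys': ['DEEPSEEK_API_KEY']}

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"
print(litellm.validate_environment(model="deepseek/deepseek-chat"))
# expected: {'keys_in_environment': True, 'missing_keys': []}
```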
@@ -739,6 +739,24 @@
        "litellm_provider": "mistral",
        "mode": "embedding"
    },
+    "deepseek-chat": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
+    "deepseek-coder": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
    "groq/llama2-70b-4096": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
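These entries are what litellm's cost tracking reads for Deepseek. A small sanity check with `cost_per_token`, assuming the updated price map is the one loaded:

```python
from litellm import cost_per_token

prompt_cost, completion_cost = cost_per_token(
    model="deepseek-chat",
    prompt_tokens=1000,
    completion_tokens=1000,
)
print(prompt_cost, completion_cost)
# expected, given the prices above: 0.00014 and 0.00028 (1000 tokens each way)
```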
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.36.1"
+version = "1.36.2"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.36.1"
+version = "1.36.2"
 version_files = [
     "pyproject.toml:^version"
 ]