Merge branch 'main' into litellm_webhook_support

Commit 707cf24472: 19 changed files with 832 additions and 90 deletions
@@ -151,3 +151,19 @@ response = image_generation(
 )
 print(f"response: {response}")
 ```
+
+## VertexAI - Image Generation Models
+
+### Usage
+
+Use this for image generation models on VertexAI
+
+```python
+response = litellm.image_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+)
+print(f"response: {response}")
+```
@@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
 | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
 | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
+
+## Image Generation Models
+
+Usage
+
+```python
+response = await litellm.aimage_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+)
+```
+
+**Generating multiple images**
+
+Use the `n` parameter to pass how many images you want generated
+
+```python
+response = await litellm.aimage_generation(
+    prompt="An olympic size swimming pool",
+    model="vertex_ai/imagegeneration@006",
+    vertex_ai_project="adroit-crow-413218",
+    vertex_ai_location="us-central1",
+    n=1,
+)
+```
+
 ## Extra
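Editor's note: the `response.data` entries returned by the examples above are `litellm.ImageObject` items carrying a base64-encoded payload (see the `ImageObject` class and the Vertex AI handler added further down in this diff). The following is a minimal consumption sketch, not part of the commit itself; the output filenames are illustrative, and it assumes each entry populates `b64_json` rather than `url`:

```python
import asyncio
import base64

import litellm


async def main():
    response = await litellm.aimage_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
        vertex_ai_project="adroit-crow-413218",
        vertex_ai_location="us-central1",
        n=1,
    )
    # Vertex AI predictions come back base64-encoded (mimeType image/png),
    # so each litellm.ImageObject exposes the bytes via b64_json.
    for i, img in enumerate(response.data):
        if img.b64_json is not None:
            with open(f"pool_{i}.png", "wb") as f:  # hypothetical filename
                f.write(base64.b64decode(img.b64_json))


asyncio.run(main())
```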
@@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
     def __init__(self):
         pass
 
-    #### ASYNC ####
-
-    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
-        pass
-
-    async def async_log_pre_api_call(self, model, messages, kwargs):
-        pass
-
-    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        pass
-
-    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        pass
-
     #### CALL HOOKS - proxy only ####
 
-    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
+        "completion",
+        "text_completion",
+        "embeddings",
+        "image_generation",
+        "moderation",
+        "audio_transcription",
+    ]) -> Optional[dict, str, Exception]:
         data["model"] = "my-new-model"
         return data
+
+    async def async_post_call_failure_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        pass
+
+    async def async_post_call_success_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        pass
+
+    async def async_moderation_hook(  # call made in parallel to llm api call
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        pass
+
+    async def async_post_call_streaming_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: str,
+    ):
+        pass
 proxy_handler_instance = MyCustomHandler()
 ```
@@ -191,3 +210,99 @@ general_settings:
 **Result**
 
 <Image img={require('../../img/end_user_enforcement.png')}/>
+
+## Advanced - Return rejected message as response
+
+For chat completions and text completion calls, you can return a rejected message as a user response.
+
+Do this by returning a string. LiteLLM takes care of returning the response in the correct format depending on the endpoint and if it's streaming/non-streaming.
+
+For non-chat/text completion endpoints, this response is returned as a 400 status code exception.
+
+### 1. Create Custom Handler
+
+```python
+from litellm.integrations.custom_logger import CustomLogger
+import litellm
+from litellm.utils import get_formatted_prompt
+
+# This file includes the custom callbacks for LiteLLM Proxy
+# Once defined, these can be passed in proxy_config.yaml
+class MyCustomHandler(CustomLogger):
+    def __init__(self):
+        pass
+
+    #### CALL HOOKS - proxy only ####
+
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
+        "completion",
+        "text_completion",
+        "embeddings",
+        "image_generation",
+        "moderation",
+        "audio_transcription",
+    ]) -> Optional[dict, str, Exception]:
+        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)
+
+        if "Hello world" in formatted_prompt:
+            return "This is an invalid response"
+
+        return data
+
+proxy_handler_instance = MyCustomHandler()
+```
+
+### 2. Update config.yaml
+
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
+```
+
+### 3. Test it!
+
+```shell
+$ litellm /path/to/config.yaml
+```
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--data ' {
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+            "role": "user",
+            "content": "Hello world"
+        }
+    ],
+}'
+```
+
+**Expected Response**
+
+```
+{
+    "id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360",
+    "choices": [
+        {
+            "finish_reason": "stop",
+            "index": 0,
+            "message": {
+                "content": "This is an invalid response.", # 👈 REJECTED RESPONSE
+                "role": "assistant"
+            }
+        }
+    ],
+    "created": 1716234198,
+    "model": null,
+    "object": "chat.completion",
+    "system_fingerprint": null,
+    "usage": {}
+}
+```
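Editor's note: the same guardrail also covers the OpenAI-style text completion route. When the hook returns a string for a `text_completion` call, the proxy returns a `TextCompletionResponse` whose `choices[0].text` carries the rejection message (see the `RejectedRequestError` handling added to `proxy_server.py` further down in this diff). A minimal sketch of exercising that path with the OpenAI SDK against the config above; it assumes the proxy is running on the default port with no key auth configured:

```python
import openai

# Point the OpenAI SDK at the LiteLLM proxy started in step 3.
client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")

response = client.completions.create(
    model="gpt-3.5-turbo",
    prompt="Hello world",  # trips the "Hello world" check in the custom handler
)
# The rejection string is returned as the completion text.
print(response.choices[0].text)
```

For routes other than chat/text completion, the same returned string surfaces as a 400 error instead, as noted above.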
@@ -724,6 +724,9 @@ from .utils import (
     get_supported_openai_params,
     get_api_base,
     get_first_chars_messages,
+    ModelResponse,
+    ImageResponse,
+    ImageObject,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
@@ -177,6 +177,32 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         )  # Call the base class constructor with the parameters it needs
 
 
+# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
+class RejectedRequestError(BadRequestError):  # type: ignore
+    def __init__(
+        self,
+        message,
+        model,
+        llm_provider,
+        request_data: dict,
+        litellm_debug_info: Optional[str] = None,
+    ):
+        self.status_code = 400
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        self.litellm_debug_info = litellm_debug_info
+        self.request_data = request_data
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        response = httpx.Response(status_code=500, request=request)
+        super().__init__(
+            message=self.message,
+            model=self.model,  # type: ignore
+            llm_provider=self.llm_provider,  # type: ignore
+            response=response,
+        )  # Call the base class constructor with the parameters it needs
+
+
 class ContentPolicyViolationError(BadRequestError):  # type: ignore
     # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
     def __init__(
@@ -4,7 +4,6 @@ import dotenv, os
 
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-
 from typing import Literal, Union, Optional
 import traceback
 
@@ -64,8 +63,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
         pass
 
     async def async_post_call_failure_hook(
@@ -871,27 +871,37 @@ Model Info:
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         """Log deployment latency"""
-        if "daily_reports" in self.alert_types:
-            model_id = (
-                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
-            )
-            response_s: timedelta = end_time - start_time
-
-            final_value = response_s
-            total_tokens = 0
-
-            if isinstance(response_obj, litellm.ModelResponse):
-                completion_tokens = response_obj.usage.completion_tokens
-                final_value = float(response_s.total_seconds() / completion_tokens)
-
-            await self.async_update_daily_reports(
-                DeploymentMetrics(
-                    id=model_id,
-                    failed_request=False,
-                    latency_per_output_token=final_value,
-                    updated_at=litellm.utils.get_utc_datetime(),
-                )
-            )
+        try:
+            if "daily_reports" in self.alert_types:
+                model_id = (
+                    kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
+                )
+                response_s: timedelta = end_time - start_time
+
+                final_value = response_s
+                total_tokens = 0
+
+                if isinstance(response_obj, litellm.ModelResponse):
+                    completion_tokens = response_obj.usage.completion_tokens
+                    if completion_tokens is not None and completion_tokens > 0:
+                        final_value = float(
+                            response_s.total_seconds() / completion_tokens
+                        )
+
+                await self.async_update_daily_reports(
+                    DeploymentMetrics(
+                        id=model_id,
+                        failed_request=False,
+                        latency_per_output_token=final_value,
+                        updated_at=litellm.utils.get_utc_datetime(),
+                    )
+                )
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
+                e,
+            )
+            pass
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         """Log failure + deployment latency"""
@@ -96,7 +96,7 @@ class MistralConfig:
         safe_prompt: Optional[bool] = None,
         response_format: Optional[dict] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
@@ -211,7 +211,7 @@ class OpenAIConfig:
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
@@ -335,7 +335,7 @@ class OpenAITextCompletionConfig:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
     ) -> None:
-        locals_ = locals()
+        locals_ = locals().copy()
         for key, value in locals_.items():
             if key != "self" and value is not None:
                 setattr(self.__class__, key, value)
litellm/llms/vertex_httpx.py (new file, 224 lines added)
@@ -0,0 +1,224 @@
+import os, types
+import json
+from enum import Enum
+import requests  # type: ignore
+import time
+from typing import Callable, Optional, Union, List, Any, Tuple
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
+import litellm, uuid
+import httpx, inspect  # type: ignore
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from .base import BaseLLM
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexLLM(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+        self.access_token: Optional[str] = None
+        self.refresh_token: Optional[str] = None
+        self._credentials: Optional[Any] = None
+        self.project_id: Optional[str] = None
+        self.async_handler: Optional[AsyncHTTPHandler] = None
+
+    def load_auth(self) -> Tuple[Any, str]:
+        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
+        import google.auth as google_auth
+
+        credentials, project_id = google_auth.default(
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+
+        credentials.refresh(Request())
+
+        if not project_id:
+            raise ValueError("Could not resolve project_id")
+
+        if not isinstance(project_id, str):
+            raise TypeError(
+                f"Expected project_id to be a str but got {type(project_id)}"
+            )
+
+        return credentials, project_id
+
+    def refresh_auth(self, credentials: Any) -> None:
+        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
+
+        credentials.refresh(Request())
+
+    def _prepare_request(self, request: httpx.Request) -> None:
+        access_token = self._ensure_access_token()
+
+        if request.headers.get("Authorization"):
+            # already authenticated, nothing for us to do
+            return
+
+        request.headers["Authorization"] = f"Bearer {access_token}"
+
+    def _ensure_access_token(self) -> str:
+        if self.access_token is not None:
+            return self.access_token
+
+        if not self._credentials:
+            self._credentials, project_id = self.load_auth()
+            if not self.project_id:
+                self.project_id = project_id
+        else:
+            self.refresh_auth(self._credentials)
+
+        if not self._credentials.token:
+            raise RuntimeError("Could not resolve API token from the environment")
+
+        assert isinstance(self._credentials.token, str)
+        return self._credentials.token
+
+    def image_generation(
+        self,
+        prompt: str,
+        vertex_project: str,
+        vertex_location: str,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+        model_response=None,
+        aimg_generation=False,
+    ):
+        if aimg_generation == True:
+            response = self.aimage_generation(
+                prompt=prompt,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                model=model,
+                client=client,
+                optional_params=optional_params,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+            )
+            return response
+
+    async def aimage_generation(
+        self,
+        prompt: str,
+        vertex_project: str,
+        vertex_location: str,
+        model_response: litellm.ImageResponse,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        logging_obj=None,
+    ):
+        response = None
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            self.async_handler = client  # type: ignore
+
+        # make POST request to
+        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        """
+        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        curl -X POST \
+        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+        -H "Content-Type: application/json; charset=utf-8" \
+        -d {
+            "instances": [
+                {
+                    "prompt": "a cat"
+                }
+            ],
+            "parameters": {
+                "sampleCount": 1
+            }
+        } \
+        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+        """
+        auth_header = self._ensure_access_token()
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = await self.async_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+        """
+        Vertex AI Image generation response example:
+        {
+            "predictions": [
+                {
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
+                    "mimeType": "image/png"
+                },
+                {
+                    "mimeType": "image/png",
+                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
+                }
+            ]
+        }
+        """
+
+        _json_response = response.json()
+        _predictions = _json_response["predictions"]
+
+        _response_data: List[litellm.ImageObject] = []
+        for _prediction in _predictions:
+            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
+            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
+            _response_data.append(image_object)
+
+        model_response.data = _response_data
+
+        return model_response
@@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
 from .llms.bedrock_httpx import BedrockLLM
+from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +119,7 @@ huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
+vertex_chat_completion = VertexLLM()
 ####### COMPLETION ENDPOINTS ################
 
 
@@ -3854,6 +3856,36 @@ def image_generation(
                 model_response=model_response,
                 aimg_generation=aimg_generation,
             )
+        elif custom_llm_provider == "vertex_ai":
+            vertex_ai_project = (
+                optional_params.pop("vertex_project", None)
+                or optional_params.pop("vertex_ai_project", None)
+                or litellm.vertex_project
+                or get_secret("VERTEXAI_PROJECT")
+            )
+            vertex_ai_location = (
+                optional_params.pop("vertex_location", None)
+                or optional_params.pop("vertex_ai_location", None)
+                or litellm.vertex_location
+                or get_secret("VERTEXAI_LOCATION")
+            )
+            vertex_credentials = (
+                optional_params.pop("vertex_credentials", None)
+                or optional_params.pop("vertex_ai_credentials", None)
+                or get_secret("VERTEXAI_CREDENTIALS")
+            )
+            model_response = vertex_chat_completion.image_generation(
+                model=model,
+                prompt=prompt,
+                timeout=timeout,
+                logging_obj=litellm_logging_obj,
+                optional_params=optional_params,
+                model_response=model_response,
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                aimg_generation=aimg_generation,
+            )
+
         return model_response
     except Exception as e:
         ## Map to OpenAI Exception
@@ -18,9 +18,3 @@ model_list:
 
 router_settings:
   enable_pre_call_checks: true
-
-general_settings:
-  alerting: ["slack", "webhook"]
-
-environment_variables:
-  WEBHOOK_URL: https://webhook.site/6ab090e8-c55f-4a23-b075-3209f5c57906
@@ -251,6 +251,10 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
     llm_api_name: Optional[str] = None
     llm_api_system_prompt: Optional[str] = None
     llm_api_fail_call_string: Optional[str] = None
+    reject_as_response: Optional[bool] = Field(
+        default=False,
+        description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
+    )
 
     @root_validator(pre=True)
     def check_llm_api_params(cls, values):
@@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
         try:
             assert call_type in [
                 "completion",
+                "text_completion",
                 "embeddings",
                 "image_generation",
                 "moderation",
@@ -192,6 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
             return data
 
         except HTTPException as e:
+
+            if (
+                e.status_code == 400
+                and isinstance(e.detail, dict)
+                and "error" in e.detail
+                and self.prompt_injection_params is not None
+                and self.prompt_injection_params.reject_as_response
+            ):
+                return e.detail["error"]
             raise e
         except Exception as e:
             traceback.print_exc()
@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
     get_actual_routes,
 )
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.exceptions import RejectedRequestError
 
 try:
     from litellm._version import version
@@ -3649,7 +3650,6 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     data = {}
-    check_request_disconnected = None
     try:
         body = await request.body()
         body_str = body.decode()
@@ -3767,8 +3767,8 @@ async def chat_completion(
 
         data["litellm_logging_obj"] = logging_obj
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
 
@@ -3832,9 +3832,6 @@ async def chat_completion(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
 
-        check_request_disconnected = asyncio.create_task(
-            check_request_disconnection(request, llm_responses)
-        )
         responses = await llm_responses
 
         response = responses[1]
@@ -3886,6 +3883,40 @@ async def chat_completion(
         )
 
         return response
+    except RejectedRequestError as e:
+        _data = e.request_data
+        _data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=_data,
+        )
+        _chat_response = litellm.ModelResponse()
+        _chat_response.choices[0].message.content = e.message  # type: ignore
+
+        if data.get("stream", None) is not None and data["stream"] == True:
+            _iterator = litellm.utils.ModelResponseIterator(
+                model_response=_chat_response, convert_to_delta=True
+            )
+            _streaming_response = litellm.CustomStreamWrapper(
+                completion_stream=_iterator,
+                model=data.get("model", ""),
+                custom_llm_provider="cached_response",
+                logging_obj=data.get("litellm_logging_obj", None),
+            )
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=_data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+        _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        _chat_response.usage = _usage  # type: ignore
+        return _chat_response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         traceback.print_exc()
@@ -3916,9 +3947,6 @@ async def chat_completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
-    finally:
-        if check_request_disconnected is not None:
-            check_request_disconnected.cancel()
 
 
 @router.post(
@@ -3945,7 +3973,6 @@ async def completion(
 ):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     data = {}
-    check_request_disconnected = None
     try:
         body = await request.body()
         body_str = body.decode()
@@ -4004,8 +4031,8 @@ async def completion(
         data["model"] = litellm.model_alias_map[data["model"]]
 
     ### CALL HOOKS ### - modify incoming data before calling the model
-    data = await proxy_logging_obj.pre_call_hook(
-        user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+    data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+        user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
     )
 
     ### ROUTE THE REQUESTs ###
@@ -4045,9 +4072,6 @@ async def completion(
             + data.get("model", "")
         },
     )
-    check_request_disconnected = asyncio.create_task(
-        check_request_disconnection(request, llm_response)
-    )
 
     # Await the llm_response task
     response = await llm_response
@@ -4091,6 +4115,46 @@ async def completion(
         )
 
         return response
+    except RejectedRequestError as e:
+        _data = e.request_data
+        _data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=_data,
+        )
+        if _data.get("stream", None) is not None and _data["stream"] == True:
+            _chat_response = litellm.ModelResponse()
+            _usage = litellm.Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            _chat_response.usage = _usage  # type: ignore
+            _chat_response.choices[0].message.content = e.message  # type: ignore
+            _iterator = litellm.utils.ModelResponseIterator(
+                model_response=_chat_response, convert_to_delta=True
+            )
+            _streaming_response = litellm.TextCompletionStreamWrapper(
+                completion_stream=_iterator,
+                model=_data.get("model", ""),
+            )
+
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+                headers={},
+            )
+        else:
+            _response = litellm.TextCompletionResponse()
+            _response.choices[0].text = e.message
+            return _response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
@@ -4112,9 +4176,6 @@ async def completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
-    finally:
-        if check_request_disconnected is not None:
-            check_request_disconnected.cancel()
 
 
 @router.post(
@@ -7761,6 +7822,12 @@ async def team_info(
         team_info = await prisma_client.get_data(
             team_id=team_id, table_name="team", query_type="find_unique"
         )
+        if team_info is None:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail={"message": f"Team not found, passed team id: {team_id}."},
+            )
+
         ## GET ALL KEYS ##
         keys = await prisma_client.get_data(
             team_id=team_id,
@@ -8993,9 +9060,25 @@ async def google_login(request: Request):
     PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
     Example:
     """
+    global premium_user
     microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
     google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
     generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+    ####### Check if user is a Enterprise / Premium User #######
+    if (
+        microsoft_client_id is not None
+        or google_client_id is not None
+        or generic_client_id is not None
+    ):
+        if premium_user != True:
+            raise ProxyException(
+                message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
+                type="auth_error",
+                param="premium_user",
+                code=status.HTTP_403_FORBIDDEN,
+            )
+
     # get url from request
     redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
     ui_username = os.getenv("UI_USERNAME")
@@ -19,8 +19,18 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
+from litellm.exceptions import RejectedRequestError
 from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
+from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -33,6 +43,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload
 
 
 def print_verbose(print_statement):
@@ -132,7 +143,13 @@ class ProxyLogging:
             alerting_args=alerting_args,
         )
 
-        if "daily_reports" in self.alert_types:
+        if (
+            self.alerting is not None
+            and "slack" in self.alerting
+            and "daily_reports" in self.alert_types
+        ):
+            # NOTE: ENSURE we only add callbacks when alerting is on
+            # We should NOT add callbacks when alerting is off
             litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
 
         if redis_cache is not None:
@@ -177,18 +194,20 @@ class ProxyLogging:
         )
         litellm.utils.set_callbacks(callback_list=callback_list)
 
+    # The actual implementation of the function
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
             "completion",
+            "text_completion",
             "embeddings",
             "image_generation",
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> dict:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
 
@@ -215,8 +234,25 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if (
+                                call_type == "completion"
+                                or call_type == "text_completion"
+                            ):
+                                raise RejectedRequestError(
+                                    message=response,
+                                    model=data.get("model", ""),
+                                    llm_provider="",
+                                    request_data=data,
+                                )
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )
             print_verbose(f"final data being sent to {call_type} call: {data}")
             return data
         except Exception as e:
@@ -441,7 +477,7 @@ class ProxyLogging:
 
         asyncio.create_task(
             self.alerting_handler(
-                message=f"LLM API call failed: {exception_str}",
+                message=f"LLM API call failed: `{exception_str}`",
                 level="High",
                 alert_type="llm_exceptions",
                 request_data=request_data,
@@ -1923,10 +1923,28 @@
            metadata = kwargs.get("litellm_params", {}).get("metadata", None)
            _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

+            exception_response = getattr(exception, "response", {})
+            exception_headers = getattr(exception_response, "headers", None)
+            _time_to_cooldown = self.cooldown_time
+
+            if exception_headers is not None:
+
+                _time_to_cooldown = (
+                    litellm.utils._get_retry_after_from_exception_header(
+                        response_headers=exception_headers
+                    )
+                )
+
+                if _time_to_cooldown < 0:
+                    # if the response headers did not read it -> set to default cooldown time
+                    _time_to_cooldown = self.cooldown_time
+
            if isinstance(_model_info, dict):
                deployment_id = _model_info.get("id", None)
                self._set_cooldown_deployments(
-                    exception_status=exception_status, deployment=deployment_id
+                    exception_status=exception_status,
+                    deployment=deployment_id,
+                    time_to_cooldown=_time_to_cooldown,
                )  # setting deployment_id in cooldown deployments
            if custom_llm_provider:
                model_name = f"{custom_llm_provider}/{model_name}"
@@ -2026,7 +2044,10 @@
        return True

    def _set_cooldown_deployments(
-        self, exception_status: Union[str, int], deployment: Optional[str] = None
+        self,
+        exception_status: Union[str, int],
+        deployment: Optional[str] = None,
+        time_to_cooldown: Optional[float] = None,
    ):
        """
        Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
@@ -2053,6 +2074,8 @@
            f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
        )
        cooldown_time = self.cooldown_time or 1
+        if time_to_cooldown is not None:
+            cooldown_time = time_to_cooldown

        if isinstance(exception_status, str):
            try:
@@ -2090,7 +2113,9 @@
                )

                self.send_deployment_cooldown_alert(
-                    deployment_id=deployment, exception_status=exception_status
+                    deployment_id=deployment,
+                    exception_status=exception_status,
+                    cooldown_time=cooldown_time,
                )
            else:
                self.failed_calls.set_cache(
@@ -3751,7 +3776,10 @@
        print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n")  # noqa

    def send_deployment_cooldown_alert(
-        self, deployment_id: str, exception_status: Union[str, int]
+        self,
+        deployment_id: str,
+        exception_status: Union[str, int],
+        cooldown_time: float,
    ):
        try:
            from litellm.proxy.proxy_server import proxy_logging_obj
@@ -3775,7 +3803,7 @@
            )
            asyncio.create_task(
                proxy_logging_obj.slack_alerting_instance.send_alert(
-                    message=f"Router: Cooling down Deployment:\nModel Name: {_model_name}\nAPI Base: {_api_base}\n{self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
+                    message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
                    alert_type="cooldown_deployment",
                    level="Low",
                )
@@ -206,6 +206,7 @@ def test_completion_bedrock_claude_sts_client_auth():
 
 # test_completion_bedrock_claude_sts_client_auth()
 
+
 @pytest.mark.skip(reason="We don't have Circle CI OIDC credentials as yet")
 def test_completion_bedrock_claude_sts_oidc_auth():
     print("\ncalling bedrock claude with oidc auth")
@@ -244,7 +245,7 @@ def test_bedrock_extra_headers():
             messages=messages,
             max_tokens=10,
             temperature=0.78,
-            extra_headers={"x-key": "x_key_value"}
+            extra_headers={"x-key": "x_key_value"},
         )
         # Add any assertions here to check the response
         assert len(response.choices) > 0
@@ -259,7 +260,7 @@ def test_bedrock_claude_3():
     try:
         litellm.set_verbose = True
         data = {
-            "max_tokens": 2000,
+            "max_tokens": 100,
             "stream": False,
             "temperature": 0.3,
             "messages": [
@@ -282,6 +283,7 @@ def test_bedrock_claude_3():
         }
         response: ModelResponse = completion(
             model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+            num_retries=3,
             # messages=messages,
             # max_tokens=10,
             # temperature=0.78,
@@ -169,3 +169,36 @@ async def test_aimage_generation_bedrock_with_optional_params():
             pass
         else:
             pytest.fail(f"An exception occurred - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_aimage_generation_vertex_ai():
+    from test_amazing_vertex_completion import load_vertex_ai_credentials
+
+    litellm.set_verbose = True
+
+    load_vertex_ai_credentials()
+    try:
+        response = await litellm.aimage_generation(
+            prompt="An olympic size swimming pool",
+            model="vertex_ai/imagegeneration@006",
+            vertex_ai_project="adroit-crow-413218",
+            vertex_ai_location="us-central1",
+            n=1,
+        )
+        assert response.data is not None
+        assert len(response.data) > 0
+
+        for d in response.data:
+            assert isinstance(d, litellm.ImageObject)
+            print("data in response.data", d)
+            assert d.b64_json is not None
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
+    except Exception as e:
+        if "Your task failed as a result of our safety system." in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
litellm/utils.py (117 changed lines)
@@ -965,10 +965,54 @@ class TextCompletionResponse(OpenAIObject):
             setattr(self, key, value)
 
 
+class ImageObject(OpenAIObject):
+    """
+    Represents the url or the content of an image generated by the OpenAI API.
+
+    Attributes:
+    b64_json: The base64-encoded JSON of the generated image, if response_format is b64_json.
+    url: The URL of the generated image, if response_format is url (default).
+    revised_prompt: The prompt that was used to generate the image, if there was any revision to the prompt.
+
+    https://platform.openai.com/docs/api-reference/images/object
+    """
+
+    b64_json: Optional[str] = None
+    url: Optional[str] = None
+    revised_prompt: Optional[str] = None
+
+    def __init__(self, b64_json=None, url=None, revised_prompt=None):
+
+        super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
 class ImageResponse(OpenAIObject):
     created: Optional[int] = None
 
-    data: Optional[list] = None
+    data: Optional[List[ImageObject]] = None
 
     usage: Optional[dict] = None
 
@@ -4902,6 +4946,14 @@ def get_optional_params_image_gen(
             width, height = size.split("x")
             optional_params["width"] = int(width)
             optional_params["height"] = int(height)
+    elif custom_llm_provider == "vertex_ai":
+        supported_params = ["n"]
+        """
+        All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        """
+        _check_valid_arg(supported_params=supported_params)
+        if n is not None:
+            optional_params["sampleCount"] = int(n)
 
     for k in passed_params.keys():
         if k not in default_params.keys():
@@ -6440,6 +6492,7 @@ def get_formatted_prompt(
         "image_generation",
         "audio_transcription",
         "moderation",
+        "text_completion",
     ],
 ) -> str:
     """
@@ -6452,6 +6505,8 @@
         for m in data["messages"]:
             if "content" in m and isinstance(m["content"], str):
                 prompt += m["content"]
+    elif call_type == "text_completion":
+        prompt = data["prompt"]
     elif call_type == "embedding" or call_type == "moderation":
         if isinstance(data["input"], str):
             prompt = data["input"]
@@ -8019,11 +8074,8 @@ def _should_retry(status_code: int):
     return False
 
 
-def _calculate_retry_after(
-    remaining_retries: int,
-    max_retries: int,
+def _get_retry_after_from_exception_header(
     response_headers: Optional[httpx.Headers] = None,
-    min_timeout: int = 0,
 ):
     """
     Reimplementation of openai's calculate retry after, since that one can't be imported.
@@ -8049,10 +8101,20 @@
                retry_after = int(retry_date - time.time())
        else:
            retry_after = -1
+        return retry_after
 
-    except Exception:
+    except Exception as e:
        retry_after = -1
 
+
+def _calculate_retry_after(
+    remaining_retries: int,
+    max_retries: int,
+    response_headers: Optional[httpx.Headers] = None,
+    min_timeout: int = 0,
+):
+    retry_after = _get_retry_after_from_exception_header(response_headers)
+
     # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
     if 0 < retry_after <= 60:
         return retry_after
@@ -8263,18 +8325,18 @@ def exception_type(
         _deployment = _metadata.get("deployment")
         extra_information = f"\nModel: {model}"
         if _api_base:
-            extra_information += f"\nAPI Base: {_api_base}"
+            extra_information += f"\nAPI Base: `{_api_base}`"
         if messages and len(messages) > 0:
-            extra_information += f"\nMessages: {messages}"
+            extra_information += f"\nMessages: `{messages}`"
 
         if _model_group is not None:
-            extra_information += f"\nmodel_group: {_model_group}\n"
+            extra_information += f"\nmodel_group: `{_model_group}`\n"
         if _deployment is not None:
-            extra_information += f"\ndeployment: {_deployment}\n"
+            extra_information += f"\ndeployment: `{_deployment}`\n"
         if _vertex_project is not None:
-            extra_information += f"\nvertex_project: {_vertex_project}\n"
+            extra_information += f"\nvertex_project: `{_vertex_project}`\n"
         if _vertex_location is not None:
-            extra_information += f"\nvertex_location: {_vertex_location}\n"
+            extra_information += f"\nvertex_location: `{_vertex_location}`\n"
 
         # on litellm proxy add key name + team to exceptions
         extra_information = _add_key_name_and_team_to_alert(
@@ -12187,3 +12249,34 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
         return request_info
     except:
         return request_info
+
+
+class ModelResponseIterator:
+    def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False):
+        if convert_to_delta == True:
+            self.model_response = ModelResponse(stream=True)
+            _delta = self.model_response.choices[0].delta  # type: ignore
+            _delta.content = model_response.choices[0].message.content  # type: ignore
+        else:
+            self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response