BerriAI/litellm: Merge branch 'main' into litellm_fix_in_mem_usage

Commit 8e3a073323: 17 changed files with 1332 additions and 126 deletions
@ -152,3 +152,104 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
```

### Dynamic TPM Allocation

Prevent projects from gobbling too much quota.

Dynamically allocate TPM quota to API keys, based on the number of keys active in that minute.

1. Setup config.yaml

```yaml
model_list:
- model_name: my-fake-model
  litellm_params:
    model: gpt-3.5-turbo
    api_key: my-fake-key
    mock_response: hello-world
    tpm: 60

litellm_settings:
  callbacks: ["dynamic_rate_limiter"]

general_settings:
  master_key: sk-1234 # OR set `LITELLM_MASTER_KEY=".."` in your .env
  database_url: postgres://.. # OR set `DATABASE_URL=".."` in your .env
```

2. Start proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```python
"""
- Run 2 concurrent teams calling the same model
- model has 60 TPM
- Mock response returns 30 total tokens / request
- Each team will only be able to make 1 request per minute
"""
import requests
from openai import OpenAI, RateLimitError


def create_key(api_key: str, base_url: str):
    response = requests.post(
        url="{}/key/generate".format(base_url),
        json={},
        headers={"Authorization": "Bearer {}".format(api_key)},
    )

    _response = response.json()

    return _response["key"]


key_1 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
key_2 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# call proxy with key 1 - works
openai_client_1 = OpenAI(api_key=key_1, base_url="http://0.0.0.0:4000")

response = openai_client_1.chat.completions.with_raw_response.create(
    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)

print("Headers for call 1 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))


# call proxy with key 2 - works
openai_client_2 = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")

response = openai_client_2.chat.completions.with_raw_response.create(
    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
)

print("Headers for call 2 - {}".format(response.headers))
_response = response.parse()
print("Total tokens for call - {}".format(_response.usage.total_tokens))

# call proxy with key 2 - fails
try:
    openai_client_2.chat.completions.with_raw_response.create(
        model="my-fake-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    raise Exception("This should have failed!")
except RateLimitError as e:
    print("This was rate limited b/c - {}".format(str(e)))
```

**Expected Response**

```
This was rate limited b/c - Error code: 429 - {'error': {'message': {'error': 'Key=<hashed_token> over available TPM=0. Model TPM=0, Active keys=2'}, 'type': 'None', 'param': 'None', 'code': 429}}
```
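
For intuition, the 429 above follows directly from the allocation rule: each key's available TPM is the model's remaining TPM divided by the number of keys active in that minute. The helper below is only an illustrative sketch of that arithmetic, not part of the proxy API:

```python
# Illustrative sketch - mirrors the per-minute division the dynamic rate limiter performs.
def available_tpm_per_key(model_tpm: int, tokens_used: int, active_keys: int) -> int:
    """Split the model's remaining TPM evenly across the keys active this minute."""
    remaining = max(model_tpm - tokens_used, 0)
    return remaining // max(active_keys, 1)


# 60 TPM model, 2 active keys, 30 tokens per mocked response (the scenario above):
print(available_tpm_per_key(model_tpm=60, tokens_used=0, active_keys=2))   # 30 -> each key gets one 30-token call
print(available_tpm_per_key(model_tpm=60, tokens_used=60, active_keys=2))  # 0  -> the next call is rejected with a 429
```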
@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = []
|
|||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
service_callback: List[Union[str, Callable]] = []
|
||||
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
|
||||
_custom_logger_compatible_callbacks_literal = Literal[
|
||||
"lago", "openmeter", "logfire", "dynamic_rate_limiter"
|
||||
]
|
||||
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
||||
_langfuse_default_tags: Optional[
|
||||
List[
|
||||
|
@ -735,6 +737,7 @@ from .utils import (
|
|||
client,
|
||||
exception_type,
|
||||
get_optional_params,
|
||||
get_response_string,
|
||||
modify_integration,
|
||||
token_counter,
|
||||
create_pretrained_tokenizer,
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from .types.services import ServiceTypes, ServiceLoggerPayload
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
|
||||
from .integrations.custom_logger import CustomLogger
|
||||
from datetime import timedelta
|
||||
from typing import Union, Optional, TYPE_CHECKING, Any
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
from .types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger):
|
|||
call_type: str,
|
||||
duration: float,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[datetime, float]] = None,
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is successful
|
||||
|
@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger):
|
|||
error: Union[str, Exception],
|
||||
call_type: str,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is unsuccessful
|
||||
|
|
|
@ -14,6 +14,7 @@ import json
|
|||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from typing import Any, BinaryIO, List, Literal, Optional, Union
|
||||
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
|
@ -92,9 +93,22 @@ class InMemoryCache(BaseCache):
|
|||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
|
||||
if time.time() - self.last_cleaned > self.default_ttl:
|
||||
asyncio.create_task(self.clean_up_in_memory_cache())
|
||||
|
||||
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
||||
"""
|
||||
Add value to set
|
||||
"""
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or set()
|
||||
for val in value:
|
||||
init_value.add(val)
|
||||
self.set_cache(key, init_value, ttl=ttl)
|
||||
return value
|
||||
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if key in self.ttl_dict:
|
||||
|
@ -363,6 +377,7 @@ class RedisCache(BaseCache):
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
||||
call_type="async_set_cache",
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
|
@ -482,6 +497,80 @@ class RedisCache(BaseCache):
|
|||
cache_value,
|
||||
)
|
||||
|
||||
async def async_set_cache_sadd(
|
||||
self, key, value: List, ttl: Optional[float], **kwargs
|
||||
):
|
||||
start_time = time.time()
|
||||
try:
|
||||
_redis_client = self.init_async_client()
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
||||
call_type="async_set_cache_sadd",
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
value,
|
||||
)
|
||||
raise e
|
||||
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
print_verbose(
|
||||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
try:
|
||||
await redis_client.sadd(key, *value)
|
||||
if ttl is not None:
|
||||
_td = timedelta(seconds=ttl)
|
||||
await redis_client.expire(key, _td)
|
||||
print_verbose(
|
||||
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
call_type="async_set_cache_sadd",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.service_logger_obj.async_service_failure_hook(
|
||||
service=ServiceTypes.REDIS,
|
||||
duration=_duration,
|
||||
error=e,
|
||||
call_type="async_set_cache_sadd",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
||||
)
|
||||
)
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
verbose_logger.error(
|
||||
"LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
|
||||
str(e),
|
||||
value,
|
||||
)
|
||||
|
||||
async def batch_cache_write(self, key, value, **kwargs):
|
||||
print_verbose(
|
||||
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
|
||||
|
@ -1506,7 +1595,7 @@ class DualCache(BaseCache):
|
|||
key, value, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = await self.redis_cache.async_increment(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
|
@ -1515,6 +1604,38 @@ class DualCache(BaseCache):
|
|||
verbose_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
|
||||
async def async_set_cache_sadd(
|
||||
self, key, value: List, local_only: bool = False, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Add value to a set
|
||||
|
||||
Key - the key in cache
|
||||
|
||||
Value - str - the value you want to add to the set
|
||||
|
||||
Returns - None
|
||||
"""
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
_ = await self.in_memory_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
_ = await self.redis_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None), **kwargs
|
||||
)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
"LiteLLM Cache: Exception async set_cache_sadd: {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
def flush_cache(self):
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.flush_cache()
|
||||
|
|
|
@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger):
|
|||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[datetime, float]] = None,
|
||||
):
|
||||
from datetime import datetime
|
||||
|
||||
|
@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger):
|
|||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
):
|
||||
from datetime import datetime
|
||||
|
||||
|
|
|
@ -19,7 +19,8 @@ from litellm import (
|
|||
turn_off_message_logging,
|
||||
verbose_logger,
|
||||
)
|
||||
from litellm.caching import InMemoryCache, S3Cache
|
||||
|
||||
from litellm.caching import InMemoryCache, S3Cache, DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.redact_messages import (
|
||||
redact_message_input_output_from_logging,
|
||||
|
@ -1899,7 +1900,11 @@ def set_callbacks(callback_list, function_id=None):
|
|||
|
||||
def _init_custom_logger_compatible_class(
|
||||
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
|
||||
) -> Callable:
|
||||
internal_usage_cache: Optional[DualCache],
|
||||
llm_router: Optional[
|
||||
Any
|
||||
], # expect litellm.Router, but typing errors due to circular import
|
||||
) -> CustomLogger:
|
||||
if logging_integration == "lago":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, LagoLogger):
|
||||
|
@ -1935,3 +1940,58 @@ def _init_custom_logger_compatible_class(
|
|||
_otel_logger = OpenTelemetry(config=otel_config)
|
||||
_in_memory_loggers.append(_otel_logger)
|
||||
return _otel_logger # type: ignore
|
||||
elif logging_integration == "dynamic_rate_limiter":
|
||||
from litellm.proxy.hooks.dynamic_rate_limiter import (
|
||||
_PROXY_DynamicRateLimitHandler,
|
||||
)
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
|
||||
return callback # type: ignore
|
||||
|
||||
if internal_usage_cache is None:
|
||||
raise Exception(
|
||||
"Internal Error: Cache cannot be empty - internal_usage_cache={}".format(
|
||||
internal_usage_cache
|
||||
)
|
||||
)
|
||||
|
||||
dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler(
|
||||
internal_usage_cache=internal_usage_cache
|
||||
)
|
||||
|
||||
if llm_router is not None and isinstance(llm_router, litellm.Router):
|
||||
dynamic_rate_limiter_obj.update_variables(llm_router=llm_router)
|
||||
_in_memory_loggers.append(dynamic_rate_limiter_obj)
|
||||
return dynamic_rate_limiter_obj # type: ignore
|
||||
|
||||
|
||||
def get_custom_logger_compatible_class(
|
||||
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
|
||||
) -> Optional[CustomLogger]:
|
||||
if logging_integration == "lago":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, LagoLogger):
|
||||
return callback
|
||||
elif logging_integration == "openmeter":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, OpenMeterLogger):
|
||||
return callback
|
||||
elif logging_integration == "logfire":
|
||||
if "LOGFIRE_TOKEN" not in os.environ:
|
||||
raise ValueError("LOGFIRE_TOKEN not found in environment variables")
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, OpenTelemetry):
|
||||
return callback # type: ignore
|
||||
|
||||
elif logging_integration == "dynamic_rate_limiter":
|
||||
from litellm.proxy.hooks.dynamic_rate_limiter import (
|
||||
_PROXY_DynamicRateLimitHandler,
|
||||
)
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
|
||||
return callback # type: ignore
|
||||
return None
|
||||
|
|
|
@ -428,7 +428,7 @@ def mock_completion(
|
|||
model: str,
|
||||
messages: List,
|
||||
stream: Optional[bool] = False,
|
||||
mock_response: Union[str, Exception] = "This is a mock request",
|
||||
mock_response: Union[str, Exception, dict] = "This is a mock request",
|
||||
mock_tool_calls: Optional[List] = None,
|
||||
logging=None,
|
||||
custom_llm_provider=None,
|
||||
|
@ -477,6 +477,9 @@ def mock_completion(
|
|||
if time_delay is not None:
|
||||
time.sleep(time_delay)
|
||||
|
||||
if isinstance(mock_response, dict):
|
||||
return ModelResponse(**mock_response)
|
||||
|
||||
model_response = ModelResponse(stream=stream)
|
||||
if stream is True:
|
||||
# don't try to access stream object,
|
||||
|
|
|
@ -1,61 +1,10 @@
|
|||
environment_variables:
|
||||
LANGFUSE_PUBLIC_KEY: Q6K8MQN6L7sPYSJiFKM9eNrETOx6V/FxVPup4FqdKsZK1hyR4gyanlQ2KHLg5D5afng99uIt0JCEQ2jiKF9UxFvtnb4BbJ4qpeceH+iK8v/bdg==
|
||||
LANGFUSE_SECRET_KEY: 5xQ7KMa6YMLsm+H/Pf1VmlqWq1NON5IoCxABhkUBeSck7ftsj2CmpkL2ZwrxwrktgiTUBH+3gJYBX+XBk7lqOOUpvmiLjol/E5lCqq0M1CqLWA==
|
||||
SLACK_WEBHOOK_URL: RJjhS0Hhz0/s07sCIf1OTXmTGodpK9L2K9p953Z+fOX0l2SkPFT6mB9+yIrLufmlwEaku5NNEBKy//+AG01yOd+7wV1GhK65vfj3B/gTN8t5cuVnR4vFxKY5Rx4eSGLtzyAs+aIBTp4GoNXDIjroCqfCjPkItEZWCg==
|
||||
general_settings:
|
||||
alerting:
|
||||
- slack
|
||||
alerting_threshold: 300
|
||||
database_connection_pool_limit: 100
|
||||
database_connection_timeout: 60
|
||||
disable_master_key_return: true
|
||||
health_check_interval: 300
|
||||
proxy_batch_write_at: 60
|
||||
ui_access_mode: all
|
||||
# master_key: sk-1234
|
||||
litellm_settings:
|
||||
allowed_fails: 3
|
||||
failure_callback:
|
||||
- prometheus
|
||||
num_retries: 3
|
||||
service_callback:
|
||||
- prometheus_system
|
||||
success_callback:
|
||||
- langfuse
|
||||
- prometheus
|
||||
- langsmith
|
||||
model_list:
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_name: gpt-3.5-turbo
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key
|
||||
model: openai/my-fake-model
|
||||
stream_timeout: 0.001
|
||||
model_name: fake-openai-endpoint
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key
|
||||
model: openai/my-fake-model-2
|
||||
stream_timeout: 0.001
|
||||
model_name: fake-openai-endpoint
|
||||
- litellm_params:
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
model: azure/chatgpt-v-2
|
||||
stream_timeout: 0.001
|
||||
model_name: azure-gpt-3.5
|
||||
- litellm_params:
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model: text-embedding-ada-002
|
||||
model_name: text-embedding-ada-002
|
||||
- litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct
|
||||
model_name: gpt-instruct
|
||||
router_settings:
|
||||
enable_pre_call_checks: true
|
||||
redis_host: os.environ/REDIS_HOST
|
||||
redis_password: os.environ/REDIS_PASSWORD
|
||||
redis_port: os.environ/REDIS_PORT
|
||||
model_list:
|
||||
- model_name: my-fake-model
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: my-fake-key
|
||||
mock_response: hello-world
|
||||
tpm: 60
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["dynamic_rate_limiter"]
|
|
@ -30,6 +30,7 @@ model_list:
|
|||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2024-02-15-preview
|
||||
model: azure/chatgpt-v-2
|
||||
tpm: 100
|
||||
model_name: gpt-3.5-turbo
|
||||
- litellm_params:
|
||||
model: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
|
@ -40,6 +41,7 @@ model_list:
|
|||
api_version: 2024-02-15-preview
|
||||
model: azure/chatgpt-v-2
|
||||
drop_params: True
|
||||
tpm: 100
|
||||
model_name: gpt-3.5-turbo
|
||||
- model_name: tts
|
||||
litellm_params:
|
||||
|
@ -67,8 +69,7 @@ model_list:
|
|||
max_input_tokens: 80920
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
failure_callback: ["langfuse"]
|
||||
callbacks: ["dynamic_rate_limiter"]
|
||||
# default_team_settings:
|
||||
# - team_id: proj1
|
||||
# success_callback: ["langfuse"]
|
||||
|
|
litellm/proxy/hooks/dynamic_rate_limiter.py (new file, 205 lines)
@ -0,0 +1,205 @@
|
|||
# What is this?
|
||||
## Allocates dynamic tpm/rpm quota for a project based on current traffic
|
||||
## Tracks num active projects per minute
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import ModelResponse, Router
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.router import ModelGroupInfo
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
|
||||
class DynamicRateLimiterCache:
|
||||
"""
|
||||
Thin wrapper on DualCache for this file.
|
||||
|
||||
Track number of active projects calling a model.
|
||||
"""
|
||||
|
||||
def __init__(self, cache: DualCache) -> None:
|
||||
self.cache = cache
|
||||
self.ttl = 60 # 1 min ttl
|
||||
|
||||
async def async_get_cache(self, model: str) -> Optional[int]:
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
key_name = "{}:{}".format(current_minute, model)
|
||||
_response = await self.cache.async_get_cache(key=key_name)
|
||||
response: Optional[int] = None
|
||||
if _response is not None:
|
||||
response = len(_response)
|
||||
return response
|
||||
|
||||
async def async_set_cache_sadd(self, model: str, value: List):
|
||||
"""
|
||||
Add value to set.
|
||||
|
||||
Parameters:
|
||||
- model: str, the name of the model group
|
||||
- value: str, the team id
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
Raises:
|
||||
- Exception, if unable to connect to cache client (if redis caching enabled)
|
||||
"""
|
||||
try:
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
|
||||
key_name = "{}:{}".format(current_minute, model)
|
||||
await self.cache.async_set_cache_sadd(
|
||||
key=key_name, value=value, ttl=self.ttl
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occurred - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
class _PROXY_DynamicRateLimitHandler(CustomLogger):
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(self, internal_usage_cache: DualCache):
|
||||
self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
|
||||
|
||||
def update_variables(self, llm_router: Router):
|
||||
self.llm_router = llm_router
|
||||
|
||||
async def check_available_tpm(
|
||||
self, model: str
|
||||
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
|
||||
"""
|
||||
For a given model, get its available tpm
|
||||
|
||||
Returns
|
||||
- Tuple[available_tpm, model_tpm, active_projects]
|
||||
- available_tpm: int or null - always 0 or positive.
|
||||
- remaining_model_tpm: int or null. If available tpm is int, then this will be too.
|
||||
- active_projects: int or null
|
||||
"""
|
||||
active_projects = await self.internal_usage_cache.async_get_cache(model=model)
|
||||
current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
|
||||
model_group=model
|
||||
)
|
||||
model_group_info: Optional[ModelGroupInfo] = (
|
||||
self.llm_router.get_model_group_info(model_group=model)
|
||||
)
|
||||
total_model_tpm: Optional[int] = None
|
||||
if model_group_info is not None and model_group_info.tpm is not None:
|
||||
total_model_tpm = model_group_info.tpm
|
||||
|
||||
remaining_model_tpm: Optional[int] = None
|
||||
if total_model_tpm is not None and current_model_tpm is not None:
|
||||
remaining_model_tpm = total_model_tpm - current_model_tpm
|
||||
elif total_model_tpm is not None:
|
||||
remaining_model_tpm = total_model_tpm
|
||||
|
||||
available_tpm: Optional[int] = None
|
||||
|
||||
if remaining_model_tpm is not None:
|
||||
if active_projects is not None:
|
||||
available_tpm = int(remaining_model_tpm / active_projects)
|
||||
else:
|
||||
available_tpm = remaining_model_tpm
|
||||
|
||||
if available_tpm is not None and available_tpm < 0:
|
||||
available_tpm = 0
|
||||
return available_tpm, remaining_model_tpm, active_projects
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
],
|
||||
) -> Optional[
|
||||
Union[Exception, str, dict]
|
||||
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
|
||||
"""
|
||||
- For a model group
|
||||
- Check if tpm available
|
||||
- Raise RateLimitError if no tpm available
|
||||
"""
|
||||
if "model" in data:
|
||||
available_tpm, model_tpm, active_projects = await self.check_available_tpm(
|
||||
model=data["model"]
|
||||
)
|
||||
if available_tpm is not None and available_tpm == 0:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
|
||||
user_api_key_dict.api_key,
|
||||
available_tpm,
|
||||
model_tpm,
|
||||
active_projects,
|
||||
)
|
||||
},
|
||||
)
|
||||
elif available_tpm is not None:
|
||||
## UPDATE CACHE WITH ACTIVE PROJECT
|
||||
asyncio.create_task(
|
||||
self.internal_usage_cache.async_set_cache_sadd( # this is a set
|
||||
model=data["model"], # type: ignore
|
||||
value=[user_api_key_dict.token or "default_key"],
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self, user_api_key_dict: UserAPIKeyAuth, response
|
||||
):
|
||||
try:
|
||||
if isinstance(response, ModelResponse):
|
||||
model_info = self.llm_router.get_model_info(
|
||||
id=response._hidden_params["model_id"]
|
||||
)
|
||||
assert (
|
||||
model_info is not None
|
||||
), "Model info for model with id={} is None".format(
|
||||
response._hidden_params["model_id"]
|
||||
)
|
||||
available_tpm, remaining_model_tpm, active_projects = (
|
||||
await self.check_available_tpm(model=model_info["model_name"])
|
||||
)
|
||||
response._hidden_params["additional_headers"] = {
|
||||
"x-litellm-model_group": model_info["model_name"],
|
||||
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
|
||||
"x-ratelimit-remaining-model-tokens": remaining_model_tpm,
|
||||
"x-ratelimit-current-active-projects": active_projects,
|
||||
}
|
||||
|
||||
return response
|
||||
return await super().async_post_call_success_hook(
|
||||
user_api_key_dict, response
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occurred - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
return response
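
As a quick sanity check outside the proxy, the new handler can also be exercised directly. This is a minimal sketch, assuming the litellm package from this branch is importable; it mirrors how the unit tests further down construct the handler:

```python
# Sketch only: wire the new hook up to an in-memory cache and a Router, then query availability.
import asyncio

from litellm import Router
from litellm.caching import DualCache
from litellm.proxy.hooks.dynamic_rate_limiter import _PROXY_DynamicRateLimitHandler


async def main():
    handler = _PROXY_DynamicRateLimitHandler(internal_usage_cache=DualCache())
    handler.update_variables(
        llm_router=Router(
            model_list=[
                {
                    "model_name": "my-fake-model",
                    "litellm_params": {
                        "model": "gpt-3.5-turbo",
                        "api_key": "my-fake-key",
                        "mock_response": "hello-world",
                        "tpm": 60,
                    },
                }
            ]
        )
    )
    # With no usage recorded yet, available TPM equals the configured model TPM.
    available_tpm, remaining_model_tpm, active_projects = await handler.check_available_tpm(
        model="my-fake-model"
    )
    print(available_tpm, remaining_model_tpm, active_projects)


asyncio.run(main())
```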
|
|
@ -433,6 +433,7 @@ def get_custom_headers(
|
|||
version: Optional[str] = None,
|
||||
model_region: Optional[str] = None,
|
||||
fastest_response_batch_completion: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
exclude_values = {"", None}
|
||||
headers = {
|
||||
|
@ -448,6 +449,7 @@ def get_custom_headers(
|
|||
if fastest_response_batch_completion is not None
|
||||
else None
|
||||
),
|
||||
**{k: str(v) for k, v in kwargs.items()},
|
||||
}
|
||||
try:
|
||||
return {
|
||||
|
@ -2644,7 +2646,9 @@ async def startup_event():
|
|||
redis_cache=redis_usage_cache
|
||||
) # used by parallel request limiter for rate limiting keys across instances
|
||||
|
||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
proxy_logging_obj._init_litellm_callbacks(
|
||||
llm_router=llm_router
|
||||
) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
|
||||
asyncio.create_task(
|
||||
|
@ -3061,6 +3065,14 @@ async def chat_completion(
|
|||
headers=custom_headers,
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
|
@ -3070,14 +3082,10 @@ async def chat_completion(
|
|||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
fastest_response_batch_completion=fastest_response_batch_completion,
|
||||
**additional_headers,
|
||||
)
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
return response
|
||||
except RejectedRequestError as e:
|
||||
_data = e.request_data
|
||||
|
@ -3116,11 +3124,10 @@ async def chat_completion(
|
|||
except Exception as e:
|
||||
data["litellm_status"] = "fail" # used for alerting
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format(
|
||||
get_error_message_str(e=e)
|
||||
"litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format(
|
||||
get_error_message_str(e=e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
|
|
|
@ -229,31 +229,32 @@ class ProxyLogging:
|
|||
if redis_cache is not None:
|
||||
self.internal_usage_cache.redis_cache = redis_cache
|
||||
|
||||
def _init_litellm_callbacks(self):
|
||||
print_verbose("INITIALIZING LITELLM CALLBACKS!")
|
||||
def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
|
||||
self.service_logging_obj = ServiceLogging()
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter)
|
||||
litellm.callbacks.append(self.max_budget_limiter)
|
||||
litellm.callbacks.append(self.cache_control_check)
|
||||
litellm.callbacks.append(self.service_logging_obj)
|
||||
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
|
||||
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
|
||||
litellm.callbacks.append(self.cache_control_check) # type: ignore
|
||||
litellm.callbacks.append(self.service_logging_obj) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
|
||||
callback
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
||||
callback,
|
||||
internal_usage_cache=self.internal_usage_cache,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
litellm.input_callback.append(callback) # type: ignore
|
||||
if callback not in litellm.success_callback:
|
||||
litellm.success_callback.append(callback)
|
||||
litellm.success_callback.append(callback) # type: ignore
|
||||
if callback not in litellm.failure_callback:
|
||||
litellm.failure_callback.append(callback)
|
||||
litellm.failure_callback.append(callback) # type: ignore
|
||||
if callback not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append(callback)
|
||||
litellm._async_success_callback.append(callback) # type: ignore
|
||||
if callback not in litellm._async_failure_callback:
|
||||
litellm._async_failure_callback.append(callback)
|
||||
litellm._async_failure_callback.append(callback) # type: ignore
|
||||
|
||||
if (
|
||||
len(litellm.input_callback) > 0
|
||||
|
@ -301,10 +302,19 @@ class ProxyLogging:
|
|||
|
||||
try:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
|
||||
callback.__class__
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
callback
|
||||
)
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if (
|
||||
_callback is not None
|
||||
and isinstance(_callback, CustomLogger)
|
||||
and "async_pre_call_hook" in vars(_callback.__class__)
|
||||
):
|
||||
response = await callback.async_pre_call_hook(
|
||||
response = await _callback.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=self.call_details["user_api_key_cache"],
|
||||
data=data,
|
||||
|
@ -574,8 +584,15 @@ class ProxyLogging:
|
|||
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_failure_hook(
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
callback
|
||||
)
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if _callback is not None and isinstance(_callback, CustomLogger):
|
||||
await _callback.async_post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=original_exception,
|
||||
)
|
||||
|
@ -596,8 +613,15 @@ class ProxyLogging:
|
|||
"""
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_success_hook(
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
callback
|
||||
)
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if _callback is not None and isinstance(_callback, CustomLogger):
|
||||
await _callback.async_post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -615,14 +639,25 @@ class ProxyLogging:
|
|||
Covers:
|
||||
1. /chat/completions
|
||||
"""
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
response_str: Optional[str] = None
|
||||
if isinstance(response, ModelResponse):
|
||||
response_str = litellm.get_response_string(response_obj=response)
|
||||
if response_str is not None:
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
_callback: Optional[CustomLogger] = None
|
||||
if isinstance(callback, str):
|
||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||
callback
|
||||
)
|
||||
else:
|
||||
_callback = callback # type: ignore
|
||||
if _callback is not None and isinstance(_callback, CustomLogger):
|
||||
await _callback.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response_str
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return response
|
||||
|
||||
async def post_call_streaming_hook(
|
||||
|
|
|
@ -11,6 +11,7 @@ import asyncio
|
|||
import concurrent
|
||||
import copy
|
||||
import datetime as datetime_og
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
|
@ -90,6 +91,10 @@ from litellm.utils import (
|
|||
)
|
||||
|
||||
|
||||
class RoutingArgs(enum.Enum):
|
||||
ttl = 60 # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class Router:
|
||||
model_names: List = []
|
||||
cache_responses: Optional[bool] = False
|
||||
|
@ -387,6 +392,11 @@ class Router:
|
|||
routing_strategy=routing_strategy,
|
||||
routing_strategy_args=routing_strategy_args,
|
||||
)
|
||||
## USAGE TRACKING ##
|
||||
if isinstance(litellm._async_success_callback, list):
|
||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||
else:
|
||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||
## COOLDOWNS ##
|
||||
if isinstance(litellm.failure_callback, list):
|
||||
litellm.failure_callback.append(self.deployment_callback_on_failure)
|
||||
|
@ -2640,13 +2650,69 @@ class Router:
|
|||
time.sleep(_timeout)
|
||||
|
||||
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
original_exception.max_retries = num_retries
|
||||
original_exception.num_retries = current_attempt
|
||||
setattr(original_exception, "max_retries", num_retries)
|
||||
setattr(original_exception, "num_retries", current_attempt)
|
||||
|
||||
raise original_exception
|
||||
|
||||
### HELPER FUNCTIONS
|
||||
|
||||
async def deployment_callback_on_success(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response, # response from completion
|
||||
start_time,
|
||||
end_time, # start/end time
|
||||
):
|
||||
"""
|
||||
Track remaining tpm/rpm quota for model in model_list
|
||||
"""
|
||||
try:
|
||||
"""
|
||||
Update TPM usage on success
|
||||
"""
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
total_tokens = completion_response["usage"]["total_tokens"]
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime(
|
||||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = f"global_router:{id}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
# update cache
|
||||
|
||||
## TPM
|
||||
await self.cache.async_increment_cache(
|
||||
key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occurred - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
def deployment_callback_on_failure(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
|
@ -3812,10 +3878,39 @@ class Router:
|
|||
|
||||
model_group_info: Optional[ModelGroupInfo] = None
|
||||
|
||||
total_tpm: Optional[int] = None
|
||||
total_rpm: Optional[int] = None
|
||||
|
||||
for model in self.model_list:
|
||||
if "model_name" in model and model["model_name"] == model_group:
|
||||
# model in model group found #
|
||||
litellm_params = LiteLLM_Params(**model["litellm_params"])
|
||||
# get model tpm
|
||||
_deployment_tpm: Optional[int] = None
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("tpm", None)
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("model_info", {}).get("tpm", None)
|
||||
|
||||
if _deployment_tpm is not None:
|
||||
if total_tpm is None:
|
||||
total_tpm = 0
|
||||
total_tpm += _deployment_tpm # type: ignore
|
||||
# get model rpm
|
||||
_deployment_rpm: Optional[int] = None
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("rpm", None)
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("model_info", {}).get("rpm", None)
|
||||
|
||||
if _deployment_rpm is not None:
|
||||
if total_rpm is None:
|
||||
total_rpm = 0
|
||||
total_rpm += _deployment_rpm # type: ignore
|
||||
# get model info
|
||||
try:
|
||||
model_info = litellm.get_model_info(model=litellm_params.model)
|
||||
|
@ -3929,8 +4024,44 @@ class Router:
|
|||
"supported_openai_params"
|
||||
]
|
||||
|
||||
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
|
||||
if total_tpm is not None and model_group_info is not None:
|
||||
model_group_info.tpm = total_tpm
|
||||
|
||||
if total_rpm is not None and model_group_info is not None:
|
||||
model_group_info.rpm = total_rpm
|
||||
|
||||
return model_group_info
|
||||
|
||||
async def get_model_group_usage(self, model_group: str) -> Optional[int]:
|
||||
"""
|
||||
Returns remaining tpm quota for model group
|
||||
"""
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime(
|
||||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
tpm_keys: List[str] = []
|
||||
for model in self.model_list:
|
||||
if "model_name" in model and model["model_name"] == model_group:
|
||||
tpm_keys.append(
|
||||
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
|
||||
)
|
||||
|
||||
## TPM
|
||||
tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
|
||||
keys=tpm_keys
|
||||
)
|
||||
tpm_usage: Optional[int] = None
|
||||
if tpm_usage_list is not None:
|
||||
for t in tpm_usage_list:
|
||||
if isinstance(t, int):
|
||||
if tpm_usage is None:
|
||||
tpm_usage = 0
|
||||
tpm_usage += t
|
||||
|
||||
return tpm_usage
|
||||
|
||||
def get_model_ids(self) -> List[str]:
|
||||
"""
|
||||
Returns list of model id's.
|
||||
|
@ -4858,7 +4989,7 @@ class Router:
|
|||
def reset(self):
|
||||
## clean up on close
|
||||
litellm.success_callback = []
|
||||
litellm.__async_success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm._async_failure_callback = []
|
||||
self.retry_policy = None
|
||||
|
|
litellm/tests/test_dynamic_rate_limit_handler.py (new file, 486 lines)
@ -0,0 +1,486 @@
|
|||
# What is this?
|
||||
## Unit tests for 'dynamic_rate_limiter.py`
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import DualCache, Router
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.hooks.dynamic_rate_limiter import (
|
||||
_PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler,
|
||||
)
|
||||
|
||||
"""
|
||||
Basic test cases:
|
||||
|
||||
- If 1 'active' project => give all tpm
|
||||
- If 2 'active' projects => divide tpm in 2
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
|
||||
internal_cache = DualCache()
|
||||
return DynamicRateLimitHandler(internal_usage_cache=internal_cache)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response() -> litellm.ModelResponse:
|
||||
return litellm.ModelResponse(
|
||||
**{
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1699896916,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": '{\n"location": "Boston, MA"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_api_key_auth() -> UserAPIKeyAuth:
|
||||
return UserAPIKeyAuth()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_projects", [1, 2, 100])
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
|
||||
model = "my-fake-model"
|
||||
## SET CACHE W/ ACTIVE PROJECTS
|
||||
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
|
||||
|
||||
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
|
||||
model=model, value=projects
|
||||
)
|
||||
|
||||
model_tpm = 100
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
## CHECK AVAILABLE TPM PER PROJECT
|
||||
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
expected_availability = int(model_tpm / num_projects)
|
||||
|
||||
assert availability == expected_availability
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth):
|
||||
"""
|
||||
Unit test. Tests if rate limit error raised when quota exhausted.
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
model = "my-fake-model"
|
||||
## SET CACHE W/ ACTIVE PROJECTS
|
||||
projects = [str(uuid.uuid4())]
|
||||
|
||||
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
|
||||
model=model, value=projects
|
||||
)
|
||||
|
||||
model_tpm = 0
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
## CHECK AVAILABLE TPM PER PROJECT
|
||||
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
expected_availability = int(model_tpm / 1)
|
||||
|
||||
assert availability == expected_availability
|
||||
|
||||
## CHECK if exception raised
|
||||
|
||||
try:
|
||||
await dynamic_rate_limit_handler.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_auth,
|
||||
cache=DualCache(),
|
||||
data={"model": model},
|
||||
call_type="completion",
|
||||
)
|
||||
pytest.fail("Expected this to raise HTTPexception")
|
||||
except HTTPException as e:
|
||||
assert e.status_code == 429 # check if rate limit error raised
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_case(dynamic_rate_limit_handler, mock_response):
|
||||
"""
|
||||
If just 1 active project
|
||||
|
||||
it should get all the quota
|
||||
|
||||
= allow request to go through
|
||||
- update token usage
|
||||
- exhaust all tpm with just 1 project
|
||||
- assert ratelimiterror raised at 100%+1 tpm
|
||||
"""
|
||||
model = "my-fake-model"
|
||||
## model tpm - 50
|
||||
model_tpm = 50
|
||||
## tpm per request - 10
|
||||
setattr(
|
||||
mock_response,
|
||||
"usage",
|
||||
litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
|
||||
)
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
"mock_response": mock_response,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
prev_availability: Optional[int] = None
|
||||
allowed_fails = 1
|
||||
for _ in range(5):
|
||||
try:
|
||||
# check availability
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
## assert availability updated
|
||||
if prev_availability is not None and availability is not None:
|
||||
assert availability == prev_availability - 10
|
||||
|
||||
print(
|
||||
"prev_availability={}, availability={}".format(
|
||||
prev_availability, availability
|
||||
)
|
||||
)
|
||||
|
||||
prev_availability = availability
|
||||
|
||||
# make call
|
||||
await llm_router.acompletion(
|
||||
model=model, messages=[{"role": "user", "content": "hey!"}]
|
||||
)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
except Exception:
|
||||
if allowed_fails > 0:
|
||||
allowed_fails -= 1
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_cache(
|
||||
dynamic_rate_limit_handler, mock_response, user_api_key_auth
|
||||
):
|
||||
"""
|
||||
Check if active project correctly updated
|
||||
"""
|
||||
model = "my-fake-model"
|
||||
model_tpm = 50
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
"mock_response": mock_response,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
## INITIAL ACTIVE PROJECTS - ASSERT NONE
|
||||
_, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
assert active_projects is None
|
||||
|
||||
## MAKE CALL
|
||||
await dynamic_rate_limit_handler.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_auth,
|
||||
cache=DualCache(),
|
||||
data={"model": model},
|
||||
call_type="completion",
|
||||
)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
## INITIAL ACTIVE PROJECTS - ASSERT 1
|
||||
_, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
assert active_projects == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_projects", [2])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_projects(
|
||||
dynamic_rate_limit_handler, mock_response, num_projects
|
||||
):
|
||||
"""
|
||||
If 2 active project
|
||||
|
||||
it should split 50% each
|
||||
|
||||
- assert available tpm is 0 after 50%+1 tpm calls
|
||||
"""
|
||||
model = "my-fake-model"
|
||||
model_tpm = 50
|
||||
total_tokens_per_call = 10
|
||||
step_tokens_per_call_per_project = total_tokens_per_call / num_projects
|
||||
|
||||
available_tpm_per_project = int(model_tpm / num_projects)
|
||||
|
||||
## SET CACHE W/ ACTIVE PROJECTS
|
||||
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
|
||||
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
|
||||
model=model, value=projects
|
||||
)
|
||||
|
||||
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
|
||||
|
||||
setattr(
|
||||
mock_response,
|
||||
"usage",
|
||||
litellm.Usage(
|
||||
prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
|
||||
),
|
||||
)
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
"mock_response": mock_response,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
prev_availability: Optional[int] = None
|
||||
|
||||
print("expected_runs: {}".format(expected_runs))
|
||||
for i in range(expected_runs + 1):
|
||||
# check availability
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
## assert availability updated
|
||||
if prev_availability is not None and availability is not None:
|
||||
assert (
|
||||
availability == prev_availability - step_tokens_per_call_per_project
|
||||
), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
|
||||
availability,
|
||||
prev_availability - 10,
|
||||
i,
|
||||
step_tokens_per_call_per_project,
|
||||
model_tpm,
|
||||
)
|
||||
|
||||
print(
|
||||
"prev_availability={}, availability={}".format(
|
||||
prev_availability, availability
|
||||
)
|
||||
)
|
||||
|
||||
prev_availability = availability
|
||||
|
||||
# make call
|
||||
await llm_router.acompletion(
|
||||
model=model, messages=[{"role": "user", "content": "hey!"}]
|
||||
)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# check availability
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
assert availability == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_projects", [2])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_projects_e2e(
|
||||
dynamic_rate_limit_handler, mock_response, num_projects
|
||||
):
|
||||
"""
|
||||
2 parallel calls with different keys, same model
|
||||
|
||||
If 2 active project
|
||||
|
||||
it should split 50% each
|
||||
|
||||
- assert available tpm is 0 after 50%+1 tpm calls
|
||||
"""
|
||||
model = "my-fake-model"
|
||||
model_tpm = 50
|
||||
total_tokens_per_call = 10
|
||||
step_tokens_per_call_per_project = total_tokens_per_call / num_projects
|
||||
|
||||
available_tpm_per_project = int(model_tpm / num_projects)
|
||||
|
||||
## SET CACHE W/ ACTIVE PROJECTS
|
||||
projects = [str(uuid.uuid4()) for _ in range(num_projects)]
|
||||
await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd(
|
||||
model=model, value=projects
|
||||
)
|
||||
|
||||
expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project)
|
||||
|
||||
setattr(
|
||||
mock_response,
|
||||
"usage",
|
||||
litellm.Usage(
|
||||
prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call
|
||||
),
|
||||
)
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
"mock_response": mock_response,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
|
||||
|
||||
prev_availability: Optional[int] = None
|
||||
|
||||
print("expected_runs: {}".format(expected_runs))
|
||||
for i in range(expected_runs + 1):
|
||||
# check availability
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
|
||||
## assert availability updated
|
||||
if prev_availability is not None and availability is not None:
|
||||
assert (
|
||||
availability == prev_availability - step_tokens_per_call_per_project
|
||||
), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format(
|
||||
availability,
|
||||
prev_availability - 10,
|
||||
i,
|
||||
step_tokens_per_call_per_project,
|
||||
model_tpm,
|
||||
)
|
||||
|
||||
print(
|
||||
"prev_availability={}, availability={}".format(
|
||||
prev_availability, availability
|
||||
)
|
||||
)
|
||||
|
||||
prev_availability = availability
|
||||
|
||||
# make call
|
||||
await llm_router.acompletion(
|
||||
model=model, messages=[{"role": "user", "content": "hey!"}]
|
||||
)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# check availability
|
||||
availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
|
||||
model=model
|
||||
)
|
||||
assert availability == 0
|
|
@ -1730,3 +1730,99 @@ async def test_router_text_completion_client():
|
|||
print(responses)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response() -> litellm.ModelResponse:
|
||||
return litellm.ModelResponse(
|
||||
**{
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1699896916,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": '{\n"location": "Boston, MA"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_model_usage(mock_response):
|
||||
"""
|
||||
Test if tracking used model tpm works as expected
|
||||
"""
|
||||
model = "my-fake-model"
|
||||
model_tpm = 100
|
||||
setattr(
|
||||
mock_response,
|
||||
"usage",
|
||||
litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
|
||||
)
|
||||
|
||||
print(f"mock_response: {mock_response}")
|
||||
model_tpm = 100
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": model,
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "my-key",
|
||||
"api_base": "my-base",
|
||||
"tpm": model_tpm,
|
||||
"mock_response": mock_response,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
allowed_fails = 1 # allow for changing b/w minutes
|
||||
|
||||
for _ in range(2):
|
||||
try:
|
||||
_ = await llm_router.acompletion(
|
||||
model=model, messages=[{"role": "user", "content": "Hey!"}]
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
initial_usage = await llm_router.get_model_group_usage(model_group=model)
|
||||
|
||||
# completion call - 10 tokens
|
||||
_ = await llm_router.acompletion(
|
||||
model=model, messages=[{"role": "user", "content": "Hey!"}]
|
||||
)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
updated_usage = await llm_router.get_model_group_usage(model_group=model)
|
||||
|
||||
assert updated_usage == initial_usage + 10 # type: ignore
|
||||
break
|
||||
except Exception as e:
|
||||
if allowed_fails > 0:
|
||||
print(
|
||||
f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
|
||||
)
|
||||
allowed_fails -= 1
|
||||
else:
|
||||
print(f"allowed_fails: {allowed_fails}")
|
||||
raise e
|
||||
|
|
|
@ -442,6 +442,8 @@ class ModelGroupInfo(BaseModel):
|
|||
"chat", "embedding", "completion", "image_generation", "audio_transcription"
|
||||
]
|
||||
] = Field(default="chat")
|
||||
tpm: Optional[int] = None
|
||||
rpm: Optional[int] = None
|
||||
supports_parallel_function_calling: bool = Field(default=False)
|
||||
supports_vision: bool = Field(default=False)
|
||||
supports_function_calling: bool = Field(default=False)
|
||||
|
|
|
@ -340,14 +340,15 @@ def function_setup(
|
|||
)
|
||||
try:
|
||||
global callback_list, add_breadcrumb, user_logger_fn, Logging
|
||||
|
||||
function_id = kwargs["id"] if "id" in kwargs else None
|
||||
|
||||
if len(litellm.callbacks) > 0:
|
||||
for callback in litellm.callbacks:
|
||||
# check if callback is a string - e.g. "lago", "openmeter"
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(
|
||||
callback
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
||||
callback, internal_usage_cache=None, llm_router=None
|
||||
)
|
||||
if any(
|
||||
isinstance(cb, type(callback))
|
||||
|
@ -3895,12 +3896,16 @@ def get_formatted_prompt(
|
|||
|
||||
|
||||
def get_response_string(response_obj: ModelResponse) -> str:
|
||||
_choices: List[Choices] = response_obj.choices # type: ignore
|
||||
_choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
|
||||
|
||||
response_str = ""
|
||||
for choice in _choices:
|
||||
if choice.message.content is not None:
|
||||
response_str += choice.message.content
|
||||
if isinstance(choice, Choices):
|
||||
if choice.message.content is not None:
|
||||
response_str += choice.message.content
|
||||
elif isinstance(choice, StreamingChoices):
|
||||
if choice.delta.content is not None:
|
||||
response_str += choice.delta.content
|
||||
|
||||
return response_str
|
||||
|
||||
|
|