forked from phoenix/litellm-mirror
feat(proxy_server.py): check if views exist on proxy server startup +… (#6360)
* feat(proxy_server.py): check if views exist on proxy server startup + refactor startup event logic to <50 LOC
* refactor(redis_cache.py): use a default cache value when writing to r… (#6358)
* refactor(redis_cache.py): use a default cache value when writing to redis - prevent redis from blowing up in high traffic
* refactor(redis_cache.py): refactor all cache writes to use self.get_ttl - ensures the default ttl is always used when writing to redis; prevents the redis db from blowing up in prod
* feat(proxy_cli.py): add new 'log_config' cli param (#6352)
* feat(proxy_cli.py): add new 'log_config' cli param - allows passing logging.conf to uvicorn on startup
* docs(cli.md): add logging conf to uvicorn cli docs
* fix(get_llm_provider_logic.py): fix default api base for litellm_proxy - Fixes https://github.com/BerriAI/litellm/issues/6332
* feat(openai_like/embedding): add support for jina ai embeddings - Closes https://github.com/BerriAI/litellm/issues/6337
* docs(deploy.md): update entrypoint.sh filepath post-refactor - fixes outdated docs
* feat(prometheus.py): emit time_to_first_token metric on prometheus - Closes https://github.com/BerriAI/litellm/issues/6334
* fix(prometheus.py): only emit time to first token metric if stream is True - enables more accurate ttft usage
* test: handle vertex api instability
* fix(get_llm_provider_logic.py): fix import
* fix(openai.py): fix deepinfra default api base
* fix(anthropic/transformation.py): remove anthropic beta header (#6361)
* docs(sidebars.js): add jina ai embedding to docs
* docs(sidebars.js): add jina ai to left nav
* bump: version 1.50.1 → 1.50.2
* langfuse use helper for get_langfuse_logging_config
* Refactor: apply early return (#6369)
* (refactor) remove berrispendLogger - unused logging integration (#6363)
* fix remove berrispendLogger
* remove unused clickhouse logger
* fix docs configs.md
* (fix) standard logging metadata + add unit testing (#6366)
* fix setting StandardLoggingMetadata
* add unit testing for standard logging metadata
* fix otel logging test
* fix linting
* fix typing
* Revert "(fix) standard logging metadata + add unit testing (#6366)" (#6381) - this reverts commit 8359cb6fa9.
* add new 35 model card (#6378)
* Add claude 3 5 sonnet 20241022 models for all providers (#6380)
* Add Claude 3.5 v2 on Amazon Bedrock and Vertex AI.
* added anthropic/claude-3-5-sonnet-20241022
* add new 35 model card

---------

Co-authored-by: Paul Gauthier <paul@paulg.com>
Co-authored-by: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>

* test(skip-flaky-google-context-caching-test): google is not reliable. their sample code is also not working
* test(test_alangfuse.py): handle flaky langfuse test better
* (feat) Arize - Allow using Arize HTTP endpoint (#6364)
* arize use helper for get_arize_opentelemetry_config
* use helper to get Arize OTEL config
* arize add helpers for arize
* docs allow using arize http endpoint
* fix importing OTEL for Arize
* use static methods for ArizeLogger
* fix ArizeLogger tests
* Litellm dev 10 22 2024 (#6384)
* fix(utils.py): add 'disallowed_special' for token counting on .encode() - fixes error when '< endoftext >' in string
* Revert "(fix) standard logging metadata + add unit testing (#6366)" (#6381) - this reverts commit 8359cb6fa9.
* add new 35 model card (#6378)
* Add claude 3 5 sonnet 20241022 models for all providers (#6380)
* Add Claude 3.5 v2 on Amazon Bedrock and Vertex AI.
* added anthropic/claude-3-5-sonnet-20241022
* add new 35 model card

---------

Co-authored-by: Paul Gauthier <paul@paulg.com>
Co-authored-by: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>

* test(skip-flaky-google-context-caching-test): google is not reliable. their sample code is also not working
* Fix metadata being overwritten in speech() (#6295)
* fix: adding missing redis cluster kwargs (#6318)

Co-authored-by: Ali Arian <ali.arian@breadfinancial.com>

* Add support for `max_completion_tokens` in Azure OpenAI (#6376) - now that Azure supports `max_completion_tokens`, there is no need for special handling of this param; let it pass through. More details: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models?tabs=python-secure#api-support
* build(model_prices_and_context_window.json): add voyage-finance-2 pricing - Closes https://github.com/BerriAI/litellm/issues/6371
* build(model_prices_and_context_window.json): fix llama3.1 pricing model name on map - Closes https://github.com/BerriAI/litellm/issues/6310
* feat(realtime_streaming.py): just log specific events - Closes https://github.com/BerriAI/litellm/issues/6267
* fix(utils.py): more robust checking if unmapped vertex anthropic model belongs to that family of models - Fixes https://github.com/BerriAI/litellm/issues/6383
* Fix Ollama stream handling for tool calls with None content (#6155)
* test(test_max_completions): update test now that azure supports 'max_completion_tokens'
* fix(handler.py): fix linting error

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Low Jian Sheng <15527690+lowjiansheng@users.noreply.github.com>
Co-authored-by: David Manouchehri <david.manouchehri@ai.moda>
Co-authored-by: Paul Gauthier <paul@paulg.com>
Co-authored-by: John HU <hszqqq12@gmail.com>
Co-authored-by: Ali Arian <113945203+ali-arian@users.noreply.github.com>
Co-authored-by: Ali Arian <ali.arian@breadfinancial.com>
Co-authored-by: Anand Taralika <46954145+taralika@users.noreply.github.com>
Co-authored-by: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com>

* bump: version 1.50.2 → 1.50.3
* build(deps): bump http-proxy-middleware in /docs/my-website (#6395) - bumps [http-proxy-middleware](https://github.com/chimurai/http-proxy-middleware) from 2.0.6 to 2.0.7.
- [Release notes](https://github.com/chimurai/http-proxy-middleware/releases)
- [Changelog](https://github.com/chimurai/http-proxy-middleware/blob/v2.0.7/CHANGELOG.md)
- [Commits](https://github.com/chimurai/http-proxy-middleware/compare/v2.0.6...v2.0.7)

---
updated-dependencies:
- dependency-name: http-proxy-middleware
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* (docs + testing) Correctly document the timeout value used by litellm proxy is 6000 seconds + add to best practices for prod (#6339)
* fix docs use documented timeout
* document request timeout
* add test for litellm.request_timeout
* add test for checking value of timeout
* (refactor) move convert dict to model response to llm_response_utils/ (#6393)
* refactor move convert dict to model response
* fix imports
* fix import _handle_invalid_parallel_tool_calls
* (refactor) litellm.Router client initialization utils (#6394)
* refactor InitalizeOpenAISDKClient
* use helper func for _should_create_openai_sdk_client_for_model
* use static methods for set client on litellm router
* reduce LOC in _get_client_initialization_params
* fix _should_create_openai_sdk_client_for_model
* code quality fix
* test test_should_create_openai_sdk_client_for_model
* test test_get_client_initialization_params_openai
* fix mypy linting errors
* fix OpenAISDKClientInitializationParams
* test_get_client_initialization_params_all_env_vars
* test_get_client_initialization_params_azure_ai_studio_mistral
* test_get_client_initialization_params_default_values
* fix _get_client_initialization_params
* (fix) Langfuse key based logging (#6372)
* langfuse use helper for get_langfuse_logging_config
* fix get_langfuse_logger_for_request
* fix import
* fix get_langfuse_logger_for_request
* test_get_langfuse_logger_for_request_with_dynamic_params
* unit testing for test_get_langfuse_logger_for_request_with_no_dynamic_params
* parameterized langfuse testing
* fix langfuse test
* fix langfuse logging
* fix test_aaalangfuse_logging_metadata
* fix langfuse log metadata test
* fix langfuse logger
* use create_langfuse_logger_from_credentials
* fix test_get_langfuse_logger_for_request_with_no_dynamic_params
* fix correct langfuse/ folder structure
* use static methods for langfuse logger
* add comment on langfuse handler
* fix linting error
* add unit testing for langfuse logging
* fix linting
* fix failure handler langfuse
* Revert "(refactor) litellm.Router client initialization utils (#6394)" (#6403) - this reverts commit b70147f63b.
* def test_text_completion_with_echo(stream): (#6401) - test
* fix linting - remove # noqa PLR0915 from fixed function
* test: cleanup codestral tests - backend api unavailable
* (refactor) prometheus async_log_success_event to be under 100 LOC (#6416)
* unit testing for prometheus
* unit testing for success metrics
* use 1 helper for _increment_token_metrics
* use helper for _increment_remaining_budget_metrics
* use _increment_remaining_budget_metrics
* use _increment_top_level_request_and_spend_metrics
* use helper for _set_latency_metrics
* remove noqa violation
* fix test prometheus
* test prometheus
* unit testing for all prometheus helper functions
* fix prom unit tests
* fix unit tests prometheus
* fix unit test prom
* (refactor) router - use static methods for client init utils (#6420)
* use InitalizeOpenAISDKClient
* use InitalizeOpenAISDKClient static method
* fix # noqa: PLR0915
* (code cleanup) remove unused and undocumented logging integrations - litedebugger, berrispend (#6406)
* code cleanup remove unused and undocumented code files
* fix unused logging integrations cleanup
* bump: version 1.50.3 → 1.50.4

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Hakan Taşköprü <Haknt@users.noreply.github.com>
Co-authored-by: Low Jian Sheng <15527690+lowjiansheng@users.noreply.github.com>
Co-authored-by: David Manouchehri <david.manouchehri@ai.moda>
Co-authored-by: Paul Gauthier <paul@paulg.com>
Co-authored-by: John HU <hszqqq12@gmail.com>
Co-authored-by: Ali Arian <113945203+ali-arian@users.noreply.github.com>
Co-authored-by: Ali Arian <ali.arian@breadfinancial.com>
Co-authored-by: Anand Taralika <46954145+taralika@users.noreply.github.com>
Co-authored-by: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
parent cc8dd80209
commit 4e310051c7
6 changed files with 271 additions and 142 deletions
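A recurring item in the log above is the redis_cache.py change: every cache write is routed through a single TTL helper so that no key is ever written without an expiry, which is what keeps Redis from blowing up under high traffic. A minimal, hypothetical sketch of that pattern, assuming an injected async Redis client (the class and method names below are illustrative, not litellm's actual implementation):

DEFAULT_REDIS_TTL = 3600  # fall back to a 1-hour expiry so keys never live forever


class RedisCacheSketch:
    def __init__(self, redis_client, default_ttl: int = DEFAULT_REDIS_TTL):
        self.redis_client = redis_client  # e.g. a redis.asyncio.Redis instance
        self.default_ttl = default_ttl

    def get_ttl(self, **kwargs) -> int:
        # honor a caller-supplied ttl, otherwise fall back to the default
        ttl = kwargs.get("ttl")
        return int(ttl) if ttl is not None else self.default_ttl

    async def async_set_cache(self, key: str, value: str, **kwargs) -> None:
        # every write goes through get_ttl, so nothing is stored without an expiry
        await self.redis_client.set(name=key, value=value, ex=self.get_ttl(**kwargs))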
@@ -346,12 +346,14 @@ class PrometheusLogger(CustomLogger):
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object"
        )

        if standard_logging_payload is None or not isinstance(
            standard_logging_payload, dict
        ):
            raise ValueError(
                f"standard_logging_object is required, got={standard_logging_payload}"
            )

        model = kwargs.get("model", "")
        litellm_params = kwargs.get("litellm_params", {}) or {}
        _metadata = litellm_params.get("metadata", {})
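The commit log above also adds a time-to-first-token metric to this file and only emits it for streaming requests. A rough standalone sketch of that idea using prometheus_client (the metric name and wiring are assumptions for illustration, not litellm's exact code):

from prometheus_client import Histogram

# hypothetical metric; the real litellm metric name and labels may differ
litellm_time_to_first_token = Histogram(
    "litellm_time_to_first_token_seconds",
    "Time from request start until the first streamed token",
)


def record_ttft(stream: bool, start_time: float, first_token_time: float) -> None:
    # only meaningful for streaming responses, so skip non-streaming calls
    if not stream:
        return
    litellm_time_to_first_token.observe(first_token_time - start_time)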
@@ -260,6 +260,7 @@ def convert_to_model_response_object(  # noqa: PLR0915
    ] = None,  # used for supporting 'json_schema' on older models
):
    received_args = locals()

    additional_headers = get_response_headers(_response_headers)

    if hidden_params is None:
@@ -448,11 +449,13 @@ def convert_to_model_response_object(  # noqa: PLR0915
    ):
        if response_object is None:
            raise Exception("Error in response object format")

        return LiteLLMResponseObjectHandler.convert_to_image_response(
            response_object=response_object,
            model_response_object=model_response_object,
            hidden_params=hidden_params,
        )

    elif response_type == "audio_transcription" and (
        model_response_object is None
        or isinstance(model_response_object, TranscriptionResponse)
@@ -48,4 +48,3 @@ router_settings:
  redis_host: os.environ/REDIS_HOST
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
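The os.environ/... values in this config block are references, not literals; the proxy resolves them from environment variables when the config is loaded (the same convention appears in the litellm_jwtauth handling further down in this diff). A minimal sketch of that resolution rule (illustrative only, not litellm's actual loader):

import os


def resolve_config_value(value: str) -> str:
    # values written as "os.environ/VAR_NAME" are looked up in the environment
    if value.startswith("os.environ/"):
        return os.environ.get(value.split("/", 1)[1], "")
    return value


# resolve_config_value("os.environ/REDIS_HOST") returns the REDIS_HOST env var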
@@ -638,18 +638,6 @@ def _resolve_pydantic_type(typ) -> List:
    return typs


def prisma_setup(database_url: Optional[str]):
    global prisma_client, proxy_logging_obj, user_api_key_cache

    if database_url is not None:
        try:
            prisma_client = PrismaClient(
                database_url=database_url, proxy_logging_obj=proxy_logging_obj
            )
        except Exception as e:
            raise e


def load_from_azure_key_vault(use_azure_key_vault: bool = False):
    if use_azure_key_vault is False:
        return
@@ -1543,7 +1531,6 @@ class ProxyConfig:
        ## INIT PROXY REDIS USAGE CLIENT ##
        redis_usage_cache = litellm.cache.cache


    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
        """
        Load config file
@@ -2796,137 +2783,55 @@ def giveup(e):
    return result


@router.on_event("startup")
async def startup_event():  # noqa: PLR0915
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
    import json
class ProxyStartupEvent:
    @classmethod
    def _initialize_startup_logging(
        cls,
        llm_router: Optional[litellm.Router],
        proxy_logging_obj: ProxyLogging,
        redis_usage_cache: Optional[RedisCache],
    ):
        """Initialize logging and alerting on startup"""
        ## COST TRACKING ##
        cost_tracking()

    init_verbose_loggers()
        ## Error Tracking ##
        error_tracking()

    ### LOAD MASTER KEY ###
    # check if master key set in environment - load from there
    master_key = get_secret("LITELLM_MASTER_KEY", None)  # type: ignore
    # check if DATABASE_URL in environment - load from there
    if prisma_client is None:
        _db_url: Optional[str] = get_secret("DATABASE_URL", None)  # type: ignore
        prisma_setup(database_url=_db_url)
        proxy_logging_obj.startup_event(
            llm_router=llm_router, redis_usage_cache=redis_usage_cache
        )

    ### LOAD CONFIG ###
    worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
    verbose_proxy_logger.debug("worker_config: %s", worker_config)
    # check if it's a valid file path
    if env_config_yaml is not None:
        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
            config_file_path=env_config_yaml
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=env_config_yaml
            )
    elif worker_config is not None:
        if (
            isinstance(worker_config, str)
            and os.path.isfile(worker_config)
            and proxy_config.is_yaml(config_file_path=worker_config)
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config
            )
        elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
            worker_config, str
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config
            )
        elif isinstance(worker_config, dict):
            await initialize(**worker_config)
    @classmethod
    def _initialize_jwt_auth(
        cls,
        general_settings: dict,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
    ):
        """Initialize JWT auth on startup"""
        if general_settings.get("litellm_jwtauth", None) is not None:
            for k, v in general_settings["litellm_jwtauth"].items():
                if isinstance(v, str) and v.startswith("os.environ/"):
                    general_settings["litellm_jwtauth"][k] = get_secret(v)
            litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
    else:
        # if not, assume it's a json string
        worker_config = json.loads(worker_config)
        if isinstance(worker_config, dict):
            await initialize(**worker_config)

    ## CHECK PREMIUM USER
    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
            premium_user
            litellm_jwtauth = LiteLLM_JWTAuth()
        jwt_handler.update_environment(
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            litellm_jwtauth=litellm_jwtauth,
        )
    )
    if premium_user is False:
        premium_user = _license_check.is_premium()

    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
            premium_user
        )
    )

    ## COST TRACKING ##
    cost_tracking()

    ## Error Tracking ##
    error_tracking()

    ## UPDATE SLACK ALERTING ##
    proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)

    db_writer_client = HTTPHandler()

    ## UPDATE INTERNAL USAGE CACHE ##
    proxy_logging_obj.update_values(
        redis_cache=redis_usage_cache
    )  # used by parallel request limiter for rate limiting keys across instances

    proxy_logging_obj._init_litellm_callbacks(
        llm_router=llm_router
    )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

    if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types:
        asyncio.create_task(
            proxy_logging_obj.slack_alerting_instance._run_scheduled_daily_report(
                llm_router=llm_router
            )
        )  # RUN DAILY REPORT (if scheduled)

    ## JWT AUTH ##
    if general_settings.get("litellm_jwtauth", None) is not None:
        for k, v in general_settings["litellm_jwtauth"].items():
            if isinstance(v, str) and v.startswith("os.environ/"):
                general_settings["litellm_jwtauth"][k] = get_secret(v)
        litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
    else:
        litellm_jwtauth = LiteLLM_JWTAuth()
    jwt_handler.update_environment(
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
        litellm_jwtauth=litellm_jwtauth,
    )

    if use_background_health_checks:
        asyncio.create_task(
            _run_background_health_check()
        )  # start the background health check coroutine.

    if prompt_injection_detection_obj is not None:
        prompt_injection_detection_obj.update_environment(router=llm_router)

    verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
    if prisma_client is not None:
        await prisma_client.connect()

    if prisma_client is not None and master_key is not None:
    @classmethod
    def _add_master_key_hash_to_db(
        cls,
        master_key: str,
        prisma_client: PrismaClient,
        litellm_proxy_admin_name: str,
        general_settings: dict,
    ):
        """Adds master key hash to db for cost tracking"""
        if os.getenv("PROXY_ADMIN_ID", None) is not None:
            litellm_proxy_admin_name = os.getenv(
                "PROXY_ADMIN_ID", litellm_proxy_admin_name
@@ -2951,7 +2856,9 @@ async def startup_event():  # noqa: PLR0915
            )
        asyncio.create_task(task_1)

    if prisma_client is not None and litellm.max_budget > 0:
    @classmethod
    def _add_proxy_budget_to_db(cls, litellm_proxy_budget_name: str):
        """Adds a global proxy budget to db"""
        if litellm.budget_duration is None:
            raise Exception(
                "budget_duration not set on Proxy. budget_duration is required to use max_budget."
@@ -2977,8 +2884,18 @@ async def startup_event():  # noqa: PLR0915
                )
            )

    ### START BATCH WRITING DB + CHECKING NEW MODELS###
    if prisma_client is not None:
    @classmethod
    async def initialize_scheduled_background_jobs(
        cls,
        general_settings: dict,
        prisma_client: PrismaClient,
        proxy_budget_rescheduler_min_time: int,
        proxy_budget_rescheduler_max_time: int,
        proxy_batch_write_at: int,
        proxy_logging_obj: ProxyLogging,
        store_model_in_db: bool,
    ):
        """Initializes scheduled background jobs"""
        scheduler = AsyncIOScheduler()
        interval = random.randint(
            proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
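initialize_scheduled_background_jobs above builds an AsyncIOScheduler and picks the budget-reset interval with random.randint, so multiple proxy instances don't all hit the database at the same instant. A small self-contained sketch of that jittered-interval pattern with APScheduler (the job function and the example bounds are placeholders, not litellm's defaults):

import asyncio
import random

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def reset_budget_job() -> None:
    # placeholder for the real budget-reset / batch-write work
    print("running scheduled job")


async def main() -> None:
    scheduler = AsyncIOScheduler()
    # jitter the interval so separate instances don't fire in lockstep
    interval = random.randint(5, 10)  # example bounds only
    scheduler.add_job(reset_budget_job, "interval", seconds=interval)
    scheduler.start()
    await asyncio.sleep(30)  # keep the event loop alive so scheduled jobs can run


if __name__ == "__main__":
    asyncio.run(main())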
@@ -3067,6 +2984,165 @@ async def startup_event():  # noqa: PLR0915

        scheduler.start()

    @classmethod
    def _setup_prisma_client(
        cls,
        database_url: Optional[str],
        proxy_logging_obj: ProxyLogging,
        user_api_key_cache: DualCache,
    ) -> Optional[PrismaClient]:
        """
        - Sets up prisma client
        - Adds necessary views to proxy
        """
        prisma_client: Optional[PrismaClient] = None
        if database_url is not None:
            try:
                prisma_client = PrismaClient(
                    database_url=database_url, proxy_logging_obj=proxy_logging_obj
                )
            except Exception as e:
                raise e

            ## Add necessary views to proxy ##
            asyncio.create_task(
                prisma_client.check_view_exists()
            )  # check if all necessary views exist. Don't block execution

        return prisma_client


@router.on_event("startup")
async def startup_event():
    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
    import json

    init_verbose_loggers()

    ### LOAD MASTER KEY ###
    # check if master key set in environment - load from there
    master_key = get_secret("LITELLM_MASTER_KEY", None)  # type: ignore
    # check if DATABASE_URL in environment - load from there
    if prisma_client is None:
        _db_url: Optional[str] = get_secret("DATABASE_URL", None)  # type: ignore
        prisma_client = ProxyStartupEvent._setup_prisma_client(
            database_url=_db_url,
            proxy_logging_obj=proxy_logging_obj,
            user_api_key_cache=user_api_key_cache,
        )

    ### LOAD CONFIG ###
    worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
    env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
    verbose_proxy_logger.debug("worker_config: %s", worker_config)
    # check if it's a valid file path
    if env_config_yaml is not None:
        if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
            config_file_path=env_config_yaml
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=env_config_yaml
            )
    elif worker_config is not None:
        if (
            isinstance(worker_config, str)
            and os.path.isfile(worker_config)
            and proxy_config.is_yaml(config_file_path=worker_config)
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config
            )
        elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
            worker_config, str
        ):
            (
                llm_router,
                llm_model_list,
                general_settings,
            ) = await proxy_config.load_config(
                router=llm_router, config_file_path=worker_config
            )
        elif isinstance(worker_config, dict):
            await initialize(**worker_config)
        else:
            # if not, assume it's a json string
            worker_config = json.loads(worker_config)
            if isinstance(worker_config, dict):
                await initialize(**worker_config)

    ## CHECK PREMIUM USER
    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
            premium_user
        )
    )
    if premium_user is False:
        premium_user = _license_check.is_premium()

    verbose_proxy_logger.debug(
        "litellm.proxy.proxy_server.py::startup() - PREMIUM USER value - {}".format(
            premium_user
        )
    )

    ProxyStartupEvent._initialize_startup_logging(
        llm_router=llm_router,
        proxy_logging_obj=proxy_logging_obj,
        redis_usage_cache=redis_usage_cache,
    )

    ## JWT AUTH ##
    ProxyStartupEvent._initialize_jwt_auth(
        general_settings=general_settings,
        prisma_client=prisma_client,
        user_api_key_cache=user_api_key_cache,
    )

    if use_background_health_checks:
        asyncio.create_task(
            _run_background_health_check()
        )  # start the background health check coroutine.

    if prompt_injection_detection_obj is not None:  # [TODO] - REFACTOR THIS
        prompt_injection_detection_obj.update_environment(router=llm_router)

    verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
    if prisma_client is not None:
        await prisma_client.connect()

    if prisma_client is not None and master_key is not None:
        ProxyStartupEvent._add_master_key_hash_to_db(
            master_key=master_key,
            prisma_client=prisma_client,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
            general_settings=general_settings,
        )

    if prisma_client is not None and litellm.max_budget > 0:
        ProxyStartupEvent._add_proxy_budget_to_db(
            litellm_proxy_budget_name=litellm_proxy_admin_name
        )

    ### START BATCH WRITING DB + CHECKING NEW MODELS###
    if prisma_client is not None:
        await ProxyStartupEvent.initialize_scheduled_background_jobs(
            general_settings=general_settings,
            prisma_client=prisma_client,
            proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
            proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
            proxy_batch_write_at=proxy_batch_write_at,
            proxy_logging_obj=proxy_logging_obj,
            store_model_in_db=store_model_in_db,
        )


#### API ENDPOINTS ####
@router.get(
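Note that _setup_prisma_client schedules prisma_client.check_view_exists() with asyncio.create_task instead of awaiting it, so a slow or missing view never blocks server startup. A tiny standalone sketch of that fire-and-forget pattern (check_view_exists here is a stand-in coroutine, not the real DB check):

import asyncio


async def check_view_exists() -> None:
    # stand-in for the real database view check
    await asyncio.sleep(2)
    print("views verified")


async def startup() -> None:
    # schedule the check but return immediately; startup is not blocked on it
    asyncio.create_task(check_view_exists())
    print("startup finished")


async def main() -> None:
    await startup()         # "startup finished" prints right away
    await asyncio.sleep(3)  # keep the loop alive long enough for the check to finish


if __name__ == "__main__":
    asyncio.run(main())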
@@ -349,6 +349,31 @@ class ProxyLogging:
        )
        self.premium_user = premium_user

    def startup_event(
        self,
        llm_router: Optional[litellm.Router],
        redis_usage_cache: Optional[RedisCache],
    ):
        """Initialize logging and alerting on proxy startup"""
        ## UPDATE SLACK ALERTING ##
        self.slack_alerting_instance.update_values(llm_router=llm_router)

        ## UPDATE INTERNAL USAGE CACHE ##
        self.update_values(
            redis_cache=redis_usage_cache
        )  # used by parallel request limiter for rate limiting keys across instances

        self._init_litellm_callbacks(
            llm_router=llm_router
        )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

        if "daily_reports" in self.slack_alerting_instance.alert_types:
            asyncio.create_task(
                self.slack_alerting_instance._run_scheduled_daily_report(
                    llm_router=llm_router
                )
            )  # RUN DAILY REPORT (if scheduled)

    def update_values(
        self,
        alerting: Optional[List] = None,
@@ -1894,3 +1894,27 @@ async def test_proxy_model_group_info_rerank(prisma_client):
# asyncio.run(test())
# except Exception as e:
# pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_proxy_server_prisma_setup():
    from litellm.proxy.proxy_server import ProxyStartupEvent
    from litellm.proxy.utils import ProxyLogging
    from litellm.caching import DualCache

    user_api_key_cache = DualCache()

    with patch.object(
        litellm.proxy.proxy_server, "PrismaClient", new=MagicMock()
    ) as mock_prisma_client:
        mock_client = mock_prisma_client.return_value  # This is the mocked instance
        mock_client.check_view_exists = AsyncMock()  # Mock the check_view_exists method

        ProxyStartupEvent._setup_prisma_client(
            database_url=os.getenv("DATABASE_URL"),
            proxy_logging_obj=ProxyLogging(user_api_key_cache=user_api_key_cache),
            user_api_key_cache=user_api_key_cache,
        )

        await asyncio.sleep(1)
        mock_client.check_view_exists.assert_called_once()
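The await asyncio.sleep(1) in this test is there because _setup_prisma_client only schedules check_view_exists() via asyncio.create_task; the short sleep yields control to the event loop so the scheduled task can actually run before the assert_called_once() check.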