Merge branch 'main' into litellm_pass_through_endpoints_api

Krish Dholakia 2024-08-15 22:39:19 -07:00 committed by GitHub
commit b3d15ace89
55 changed files with 3658 additions and 1644 deletions

@@ -159,6 +159,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
check_file_size_under_limit,
)
from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
@@ -197,6 +198,8 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
router as pass_through_router,
)
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
@@ -1444,7 +1447,18 @@ class ProxyConfig:
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
verbose_proxy_logger.debug(
"bucket_name: %s, object_key: %s", bucket_name, object_key
)
config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key
)
else:
# default to file
config = await self.get_config(config_file_path=config_file_path)
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
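The body of get_file_contents_from_s3 is not part of this diff; only the import and the call site appear above. A minimal sketch of what such a helper could look like, assuming boto3 and a YAML config object (the function name and env-driven arguments come from the hunk above; everything else is illustrative):

# Sketch only -- not the implementation shipped in this commit.
# Assumes boto3 is available and the S3 object is a YAML proxy config.
import boto3
import yaml

def get_file_contents_from_s3_sketch(bucket_name: str, object_key: str) -> dict:
    s3_client = boto3.client("s3")
    obj = s3_client.get_object(Bucket=bucket_name, Key=object_key)
    raw = obj["Body"].read().decode("utf-8")
    # return the same dict shape that self.get_config() produces from a local file
    return yaml.safe_load(raw)
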
@@ -2652,6 +2666,15 @@ async def startup_event():
)
else:
await initialize(**worker_config)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
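For reference, the new startup branch above is driven entirely by two environment variables; the names come from the load_config hunk earlier in this diff, and the values below are examples only:

# Example environment (variable names from this diff; values are illustrative).
import os

os.environ["LITELLM_CONFIG_BUCKET_NAME"] = "my-litellm-configs"       # example bucket
os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = "prod/config.yaml"   # example object key
# With both set, startup_event() takes the new elif branch and
# proxy_config.load_config() pulls the config from S3 instead of a local file.
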
@@ -3036,68 +3059,13 @@ async def chat_completion(
### ROUTE THE REQUEST ###
# Do not change this - it should be a constant time fetch - ALWAYS
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
tasks.append(litellm.acompletion(**data))
elif "," in data["model"] and llm_router is not None:
if (
data.get("fastest_response", None) is not None
and data["fastest_response"] == True
):
tasks.append(llm_router.abatch_completion_fastest_response(**data))
else:
_models_csv_string = data.pop("model")
_models = [model.strip() for model in _models_csv_string.split(",")]
tasks.append(llm_router.abatch_completion(models=_models, **data))
elif "user_config" in data:
# initialize a new router instance. make request using this Router
router_config = data.pop("user_config")
user_router = litellm.Router(**router_config)
tasks.append(user_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
tasks.append(llm_router.acompletion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data, specific_deployment=True))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.acompletion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.acompletion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.acompletion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "chat_completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="acompletion",
llm_router=llm_router,
user_model=user_model,
)
tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
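route_request (imported from litellm.proxy.route_llm_request earlier in this diff) centralizes the precedence that each endpoint's deleted if/elif chain used to encode. A condensed sketch of that precedence, reconstructed from the removed chat_completion branch above rather than from the new module itself:

# Condensed sketch of the routing precedence that route_request() replaces.
# Reconstructed from the deleted if/elif chain; the real implementation in
# litellm/proxy/route_llm_request.py may order or name things differently.
import litellm
from fastapi import HTTPException, status

async def route_request_sketch(data: dict, llm_router, user_model, route_type: str):
    router_model_names = llm_router.model_names if llm_router is not None else []

    if "api_key" in data:
        # caller supplied their own provider key -> skip the router
        return getattr(litellm, route_type)(**data)
    if "user_config" in data:
        # build a one-off Router from the caller's own config
        user_router = litellm.Router(**data.pop("user_config"))
        return getattr(user_router, route_type)(**data)
    if llm_router is not None:
        model = data.get("model", "")
        if (
            model in router_model_names
            or model in llm_router.get_model_ids()
            or (llm_router.model_group_alias and model in llm_router.model_group_alias)
        ):
            return getattr(llm_router, route_type)(**data)
        if model in llm_router.deployment_names:
            return getattr(llm_router, route_type)(**data, specific_deployment=True)
        # fallbacks for pass_through_all_models / default_deployment omitted for brevity
    if user_model is not None:  # `litellm --model <your-model-name>`
        return getattr(litellm, route_type)(**data)
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail={"error": f"{route_type}: Invalid model name passed in model={data.get('model', '')}"},
    )

Each endpoint below follows the same pattern: it awaits route_request(...) to obtain the underlying coroutine, then either adds it to an asyncio.gather() batch (chat completions, embeddings) or awaits it directly (text completions, images, audio, moderations).
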
@@ -3320,58 +3288,15 @@ async def completion(
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
llm_response = asyncio.create_task(litellm.atext_completion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.atext_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
llm_response = asyncio.create_task(litellm.atext_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.atext_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.atext_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="atext_completion",
llm_router=llm_router,
user_model=user_model,
)
# Await the llm_response task
response = await llm_response
response = await llm_call
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
@@ -3585,59 +3510,13 @@ async def embeddings(
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
tasks.append(litellm.aembedding(**data))
elif "user_config" in data:
# initialize a new router instance. make request using this Router
router_config = data.pop("user_config")
user_router = litellm.Router(**router_config)
tasks.append(user_router.aembedding(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
tasks.append(llm_router.aembedding(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
tasks.append(
llm_router.aembedding(**data)
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data, specific_deployment=True))
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.router_general_settings.pass_through_all_models is True
):
tasks.append(litellm.aembedding(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
tasks.append(llm_router.aembedding(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
tasks.append(litellm.aembedding(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "embeddings: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="aembedding",
llm_router=llm_router,
user_model=user_model,
)
tasks.append(llm_call)
# wait for call to end
llm_responses = asyncio.gather(
@@ -3768,46 +3647,13 @@ async def image_generation(
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.aimage_generation(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.aimage_generation(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.aimage_generation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aimage_generation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aimage_generation(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "image_generation: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="aimage_generation",
llm_router=llm_router,
user_model=user_model,
)
response = await llm_call
### ALERTING ###
asyncio.create_task(
@@ -3915,44 +3761,13 @@ async def audio_speech(
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.aspeech(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.aspeech(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aspeech(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.aspeech(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.aspeech(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.aspeech(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_speech: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="aspeech",
llm_router=llm_router,
user_model=user_model,
)
response = await llm_call
### ALERTING ###
asyncio.create_task(
@@ -4085,47 +3900,13 @@ async def audio_transcriptions(
)
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.atranscription(**data)
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.atranscription(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.atranscription(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.atranscription(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.atranscription(**data)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "audio_transcriptions: Invalid model name passed in model="
+ data.get("model", "")
},
)
llm_call = await route_request(
data=data,
route_type="atranscription",
llm_router=llm_router,
user_model=user_model,
)
response = await llm_call
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
@@ -5341,40 +5122,13 @@ async def moderations(
start_time = time.time()
## ROUTE TO CORRECT ENDPOINT ##
# skip router if user passed their key
if "api_key" in data:
response = await litellm.amoderation(**data)
elif (
llm_router is not None and data.get("model") in router_model_names
): # model in router model list
response = await llm_router.amoderation(**data)
elif (
llm_router is not None and data.get("model") in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data, specific_deployment=True)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data.get("model") in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.amoderation(
**data
) # ensure this goes the llm_router, router will do the correct alias mapping
elif (
llm_router is not None
and data.get("model") not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.provider_default_deployments) > 0
)
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.amoderation(**data)
elif user_model is not None: # `litellm --model <your-model-name>`
response = await litellm.amoderation(**data)
else:
# /moderations does not need a "model" passed
# see https://platform.openai.com/docs/api-reference/moderations
response = await litellm.amoderation(**data)
llm_call = await route_request(
data=data,
route_type="amoderation",
llm_router=llm_router,
user_model=user_model,
)
response = await llm_call
### ALERTING ###
asyncio.create_task(
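One nuance carried over from the deleted /moderations branch: the old else-case called litellm.amoderation(**data) even when no "model" was supplied, since the OpenAI moderations API does not require one. Assuming route_request preserves that fallback for route_type="amoderation" (the new module is not shown in this diff), the handler above can still be exercised without a model, e.g.:

# Hedged usage sketch (inside the moderations handler, where llm_router and
# user_model are module globals); assumes route_request keeps the old
# no-"model" fallback to litellm.amoderation.
llm_call = await route_request(
    data={"input": "text to classify"},  # note: no "model" key
    route_type="amoderation",
    llm_router=llm_router,
    user_model=user_model,
)
response = await llm_call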