Mirror of https://github.com/BerriAI/litellm.git
feat(proxy_server.py): enable batch completion fastest response calls on proxy
Introduces a new `fastest_response` flag for enabling the call.
parent ecd182eb6a
commit 20106715d5

3 changed files with 32 additions and 3 deletions
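For context, a request against the proxy might look like the following. This is a minimal sketch inferred from the diff below, assuming a proxy running locally on port 4000 with a virtual key sk-1234; the model names are placeholders. Passing a comma-separated `model` list together with `fastest_response: true` routes the request through the new fastest-response batch path:

```python
# Hypothetical client call; proxy URL, key, and model names are placeholders.
import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",      # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},   # assumed virtual key
    json={
        "model": "gpt-4o,claude-3-haiku",          # comma-separated model list
        "messages": [{"role": "user", "content": "Hello"}],
        "fastest_response": True,                  # new flag from this commit
    },
)
print(resp.json())
# The proxy reports whether the fastest-response path was used via a response header.
print(resp.headers.get("x-litellm-fastest_response_batch_completion"))
```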
@@ -680,6 +680,7 @@ def completion(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "fastest_response",
     ]
 
     default_params = openai_params + litellm_params
@@ -415,6 +415,7 @@ def get_custom_headers(
     api_base: Optional[str] = None,
     version: Optional[str] = None,
     model_region: Optional[str] = None,
+    fastest_response_batch_completion: Optional[bool] = None,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -425,6 +426,11 @@
         "x-litellm-model-region": model_region,
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-fastest_response_batch_completion": (
+            str(fastest_response_batch_completion)
+            if fastest_response_batch_completion is not None
+            else None
+        ),
     }
     try:
         return {
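Given the exclude_values = {"", None} set in the hunk above, the new header presumably only shows up when the fastest-response path was actually used, since the None default would be filtered out. A standalone sketch of that filtering behavior (build_headers is a hypothetical helper, not the proxy's get_custom_headers):

```python
# Standalone sketch of the None-filtering behavior hinted at by exclude_values;
# build_headers is a hypothetical helper, not the proxy's get_custom_headers.
from typing import Optional


def build_headers(fastest_response_batch_completion: Optional[bool] = None) -> dict:
    exclude_values = {"", None}
    headers = {
        "x-litellm-fastest_response_batch_completion": (
            str(fastest_response_batch_completion)
            if fastest_response_batch_completion is not None
            else None
        ),
    }
    # Drop empty/None values so the header is only sent when the flag was used.
    return {k: v for k, v in headers.items() if v not in exclude_values}


print(build_headers())      # {}
print(build_headers(True))  # {'x-litellm-fastest_response_batch_completion': 'True'}
```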
@@ -4035,7 +4041,17 @@ async def chat_completion(
     elif "," in data["model"] and llm_router is not None:
         _models_csv_string = data.pop("model")
         _models = _models_csv_string.split(",")
-        tasks.append(llm_router.abatch_completion(models=_models, **data))
+        if (
+            data.get("fastest_response", None) is not None
+            and data["fastest_response"] == True
+        ):
+            tasks.append(
+                llm_router.abatch_completion_fastest_response(
+                    models=_models, **data
+                )
+            )
+        else:
+            tasks.append(llm_router.abatch_completion(models=_models, **data))
     elif "user_config" in data:
         # initialize a new router instance. make request using this Router
         router_config = data.pop("user_config")
@@ -4085,6 +4101,9 @@
     model_id = hidden_params.get("model_id", None) or ""
     cache_key = hidden_params.get("cache_key", None) or ""
     api_base = hidden_params.get("api_base", None) or ""
+    fastest_response_batch_completion = hidden_params.get(
+        "fastest_response_batch_completion", None
+    )
 
     # Post Call Processing
     if llm_router is not None:
@@ -4101,6 +4120,7 @@
     api_base=api_base,
     version=version,
     model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+    fastest_response_batch_completion=fastest_response_batch_completion,
 )
 selected_data_generator = select_data_generator(
     response=response,
@@ -4121,6 +4141,7 @@
         api_base=api_base,
         version=version,
         model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+        fastest_response_batch_completion=fastest_response_batch_completion,
     )
 )
 
@@ -804,9 +804,16 @@ class Router:
             pending_tasks.append(task)
 
         responses = await asyncio.gather(*_tasks, return_exceptions=True)
-        if isinstance(responses[0], Exception):
+        if isinstance(responses[0], Exception) or isinstance(
+            responses[0], BaseException
+        ):
             raise responses[0]
-        return responses[0]  # return first value from list
+        _response: Union[ModelResponse, CustomStreamWrapper] = responses[
+            0
+        ]  # return first value from list
+
+        _response._hidden_params["fastest_response_batch_completion"] = True
+        return _response
 
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
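The hunk above only shows how the Router surfaces the winning response and its `fastest_response_batch_completion` hidden param; the `abatch_completion_fastest_response` implementation itself is not part of this diff. As a rough illustration of the underlying pattern (fan out one task per model, return whichever finishes first, cancel the rest), here is a self-contained sketch; it is not litellm's Router code:

```python
# Simplified illustration of the fastest-response pattern; not litellm's Router code.
import asyncio


async def fastest_response(coros):
    """Run all coroutines concurrently, return the first finisher, cancel the rest."""
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    winner = done.pop()
    if winner.exception() is not None:
        # Mirrors the hunk above: raise an exception instead of returning it.
        raise winner.exception()
    return winner.result()


async def _demo():
    async def fake_model(name: str, delay: float) -> str:
        await asyncio.sleep(delay)
        return f"response from {name}"

    # model-b answers first, so its response is returned.
    print(await fastest_response([fake_model("model-a", 0.5), fake_model("model-b", 0.1)]))


asyncio.run(_demo())
```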