forked from phoenix/litellm-mirror
feat(proxy_server.py): enable batch completion fastest response calls on proxy
Introduces a new `fastest_response` flag for enabling the call.
parent ecd182eb6a
commit 20106715d5
3 changed files with 32 additions and 3 deletions
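
For orientation, here is a minimal sketch of how the new flag might be exercised against the proxy once this change is deployed. The base URL, API key, and model names below are placeholders and not part of this commit; the comma-separated `model` string plus the `fastest_response` flag are what trigger the new code path in `chat_completion` further down.

```python
# Hedged usage sketch: base_url, api_key, and model names are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.chat.completions.create(
    # A comma-separated model list routes the request to the batch-completion path.
    model="gpt-3.5-turbo,gpt-4",
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    # The new flag: return only the fastest model's response instead of all of them.
    extra_body={"fastest_response": True},
)
print(response)
```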
@@ -680,6 +680,7 @@ def completion(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "fastest_response",
     ]

     default_params = openai_params + litellm_params
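
Adding `fastest_response` to the params list registers it as a litellm-internal argument, so it is consumed by litellm rather than forwarded to the model provider. A simplified illustration of how such an allow-list is typically applied (this helper is illustrative only, not the actual litellm code):

```python
# Illustrative only: split caller kwargs using an allow-list like
# `default_params = openai_params + litellm_params`.
def split_kwargs(kwargs: dict, default_params: list) -> tuple:
    recognized = {k: v for k, v in kwargs.items() if k in default_params}
    provider_specific = {k: v for k, v in kwargs.items() if k not in default_params}
    return recognized, provider_specific
```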
@@ -415,6 +415,7 @@ def get_custom_headers(
     api_base: Optional[str] = None,
     version: Optional[str] = None,
     model_region: Optional[str] = None,
+    fastest_response_batch_completion: Optional[bool] = None,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -425,6 +426,11 @@ def get_custom_headers(
         "x-litellm-model-region": model_region,
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-fastest_response_batch_completion": (
+            str(fastest_response_batch_completion)
+            if fastest_response_batch_completion is not None
+            else None
+        ),
     }
     try:
         return {
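
The header value is converted to `str(...)` only when the flag is set; otherwise it stays `None`, which the `exclude_values = {"", None}` set suggests is filtered out of the returned headers, so the header should only appear on fastest-response batch calls. A hedged sketch of checking it from a client (URL, key, and model names are placeholders):

```python
import requests

# Placeholder endpoint and credentials for a locally running proxy.
resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo,gpt-4",
        "messages": [{"role": "user", "content": "hi"}],
        "fastest_response": True,
    },
)
# Expected to be "True" for fastest-response batch calls, absent otherwise.
print(resp.headers.get("x-litellm-fastest_response_batch_completion"))
```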
@@ -4035,6 +4041,16 @@ async def chat_completion(
         elif "," in data["model"] and llm_router is not None:
             _models_csv_string = data.pop("model")
             _models = _models_csv_string.split(",")
-            tasks.append(llm_router.abatch_completion(models=_models, **data))
+            if (
+                data.get("fastest_response", None) is not None
+                and data["fastest_response"] == True
+            ):
+                tasks.append(
+                    llm_router.abatch_completion_fastest_response(
+                        models=_models, **data
+                    )
+                )
+            else:
+                tasks.append(llm_router.abatch_completion(models=_models, **data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
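
In plain terms, when `model` is a comma-separated list this branch now chooses between two batch strategies: race the listed models and keep only the fastest answer, or fan out and return every model's answer. A condensed restatement (the router object is a placeholder):

```python
# Condensed, illustrative restatement of the routing decision above.
def schedule_batch(llm_router, data: dict, tasks: list) -> None:
    models = data.pop("model").split(",")
    if data.get("fastest_response") is True:
        # New path: only the fastest model's completion is returned.
        tasks.append(llm_router.abatch_completion_fastest_response(models=models, **data))
    else:
        # Existing path: one completion per requested model.
        tasks.append(llm_router.abatch_completion(models=models, **data))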
@@ -4085,6 +4101,9 @@ async def chat_completion(
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
+        fastest_response_batch_completion = hidden_params.get(
+            "fastest_response_batch_completion", None
+        )

         # Post Call Processing
         if llm_router is not None:
@@ -4101,6 +4120,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -4121,6 +4141,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
         )
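
The flag read from `hidden_params` here is the one the `Router` change below attaches to the winning response's `_hidden_params`; `chat_completion` then forwards it into `get_custom_headers` so it surfaces as a response header. A hedged sketch of observing the same flag when calling the router directly (router construction and model names are placeholders):

```python
import asyncio

async def demo(llm_router) -> None:
    # `llm_router` is assumed to be an already-configured litellm Router.
    response = await llm_router.abatch_completion_fastest_response(
        models=["gpt-3.5-turbo", "gpt-4"],
        messages=[{"role": "user", "content": "hi"}],
    )
    # Set to True by the Router change below.
    print(response._hidden_params.get("fastest_response_batch_completion"))

# asyncio.run(demo(llm_router))  # run with a configured Router instance
```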
@@ -804,9 +804,16 @@ class Router:
             pending_tasks.append(task)

         responses = await asyncio.gather(*_tasks, return_exceptions=True)
-        if isinstance(responses[0], Exception):
+        if isinstance(responses[0], Exception) or isinstance(
+            responses[0], BaseException
+        ):
             raise responses[0]
-        return responses[0]  # return first value from list
+        _response: Union[ModelResponse, CustomStreamWrapper] = responses[
+            0
+        ]  # return first value from list
+
+        _response._hidden_params["fastest_response_batch_completion"] = True
+        return _response

     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
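
Only the tail of the fastest-response method is visible in this hunk. For readers unfamiliar with the pattern, a generic "first completed wins" race in asyncio looks roughly like the sketch below; this is a common idiom, not necessarily how the method is implemented internally.

```python
import asyncio
from typing import Any, Awaitable, List

async def first_completed(coros: List[Awaitable[Any]]) -> Any:
    """Generic 'fastest wins' race: return the first finished result and
    cancel the rest. Illustrative only, not litellm's implementation."""
    tasks = [asyncio.ensure_future(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    winner = done.pop()
    if winner.exception() is not None:
        # Mirrors the diff above, which re-raises when the first result is an error.
        raise winner.exception()
    return winner.result()
```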