Merge branch 'main' into litellm_imp_mem_use

This commit is contained in:
Ishaan Jaff 2024-03-11 19:00:56 -07:00 committed by GitHub
commit cd8f25f6f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 296 additions and 65 deletions

View file

@ -967,44 +967,81 @@ class Router:
is_async: Optional[bool] = False,
**kwargs,
) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("model_info", {})
kwargs.setdefault("metadata", {}).update(
{"model_group": model, "deployment": deployment["litellm_params"]["model"]}
) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
try:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._embedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _embedding(self, input: Union[str, List], model: str, **kwargs):
try:
verbose_router_logger.debug(
f"Inside embedding()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="sync"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
return litellm.embedding(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.embedding(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def aembedding(
self,