fix(router.py): accept dynamic api key

This commit is contained in:
Krrish Dholakia 2023-12-26 13:16:11 +05:30
parent 3029e8a197
commit f5ed4992db

View file

@ -251,7 +251,20 @@ class Router:
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
return litellm.completion(
**{
**data,
@ -303,9 +316,19 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.acompletion(
**{
@ -361,9 +384,20 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.image_generation(
**{
@ -419,9 +453,20 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.aimage_generation(
**{
@ -554,8 +599,17 @@ class Router:
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
return litellm.embedding(
**{
**data,
@ -592,9 +646,19 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
return await litellm.aembedding(
**{