(fix) router: passing client

ishaan-jaff 2023-11-28 16:34:16 -08:00
parent 4d06c296e3
commit 282b9a37e5
2 changed files with 25 additions and 16 deletions


@@ -1765,6 +1765,7 @@ def embedding(
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=optional_params.pop("client", None)
)
elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
api_base = (
@@ -1798,7 +1799,8 @@
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params
optional_params=optional_params,
client=optional_params.pop("client", None)
)
elif model in litellm.cohere_embedding_models:
cohere_key = (
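
Both embedding() call sites above apply the same pattern: if the caller put a pre-built SDK client into optional_params, it is popped back out and passed explicitly as client=, so the provider handler can reuse it instead of treating it as a request parameter. A minimal sketch of that pattern, assuming an illustrative helper name and fallback (this is not the actual litellm handler code):

import openai

def _resolve_client(optional_params, api_key=None, api_base=None):
    # Reuse the client the Router attached for this deployment, if one was passed in.
    client = optional_params.pop("client", None)
    if client is not None:
        return client
    # Assumed fallback: build a fresh client for this single call.
    return openai.OpenAI(api_key=api_key, base_url=api_base)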


@@ -102,7 +102,7 @@ class Router:
self.default_litellm_params["timeout"] = timeout
self.default_litellm_params["max_retries"] = 0
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
self._start_health_check_thread()
@@ -188,7 +188,8 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
model_client = deployment.get("client", None)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
raise e
@@ -232,7 +233,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment["async_client"]
model_client = deployment.get("async_client", None)
self.total_calls[original_model_string] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[original_model_string] +=1
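
In both the sync and async completion paths the prebuilt client is now looked up with .get(..., None), so deployments that never had a client attached simply forward client=None rather than raising a KeyError (as the previous deployment["async_client"] lookup above would). A small self-contained sketch of the kwargs-merging pattern used here; the downstream function is illustrative, not litellm.completion itself:

def downstream(**kwargs):
    # The provider integration decides whether to reuse a passed-in client.
    client = kwargs.pop("client", None)
    return "reused prebuilt client" if client is not None else "built a fresh client"

deployment = {"litellm_params": {"model": "gpt-3.5-turbo"}}   # no "client" key attached
data = dict(deployment["litellm_params"])
messages = [{"role": "user", "content": "hi"}]
print(downstream(**{**data, "messages": messages, "client": deployment.get("client", None)}))
# prints "built a fresh client"; deployment["client"] would have raised KeyError here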
@@ -301,8 +302,9 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("client", None)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
async def aembedding(self,
model: str,
@@ -325,7 +327,9 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
model_client = deployment.get("async_client", None)
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
async def async_function_with_fallbacks(self, *args, **kwargs):
"""
@@ -801,23 +805,26 @@ class Router:
self.model_list = model_list
# we add api_base/api_key to each model so load balancing between azure/gpt on api_base1 and api_base2 works
for model in self.model_list:
if "azure" in model["litellm_params"]["model"]:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
if "azure" in model_name:
model["async_client"] = openai.AsyncAzureOpenAI(
api_key= model["litellm_params"]["api_key"],
azure_endpoint = model["litellm_params"]["api_base"]
api_key=litellm_params.get("api_key"),
azure_endpoint=litellm_params.get("api_base")
)
model["client"] = openai.AzureOpenAI(
api_key= model["litellm_params"]["api_key"],
azure_endpoint = model["litellm_params"]["api_base"]
api_key=litellm_params.get("api_key"),
azure_endpoint=litellm_params.get("api_base")
)
elif model["litellm_params"]["model"] in litellm.open_ai_chat_completion_models:
elif model_name in litellm.open_ai_chat_completion_models:
model["async_client"] = openai.AsyncOpenAI(
api_key= model["litellm_params"]["api_key"],
base_ur = model["litellm_params"]["api_base"]
api_key=litellm_params.get("api_key"),
base_url=litellm_params.get("api_base")
)
model["client"] = openai.OpenAI(
api_key= model["litellm_params"]["api_key"],
base_url = model["litellm_params"]["api_base"]
api_key=litellm_params.get("api_key"),
base_url=litellm_params.get("api_base")
)
model_id = ""
for key in model["litellm_params"]:
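
set_model_list now reads litellm_params defensively via .get() and attaches one sync and one async client to every Azure/OpenAI deployment, which the completion/embedding paths above pick back up. A hedged sketch of the model_list shape this code expects; all names and credentials are placeholders, and Azure deployments typically also carry an api_version:

from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",               # alias that requests are routed by
        "litellm_params": {
            "model": "azure/chatgpt-v-2",             # "azure" in model -> Azure clients attached
            "api_key": "my-azure-key",
            "api_base": "https://my-endpoint.openai.azure.com",
            "api_version": "2023-07-01-preview",
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",                 # OpenAI chat model -> OpenAI clients attached
            "api_key": "sk-my-openai-key",
            "api_base": "https://api.openai.com/v1",
        },
    },
]

router = Router(model_list=model_list)   # set_model_list attaches "client"/"async_client" per deployment
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)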