Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
(fix) router: passing client

commit 282b9a37e5
parent 4d06c296e3
2 changed files with 25 additions and 16 deletions
@@ -102,7 +102,7 @@ class Router:
         self.default_litellm_params["timeout"] = timeout
         self.default_litellm_params["max_retries"] = 0


         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
             self._start_health_check_thread()
@@ -188,7 +188,8 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            model_client = deployment.get("client", None)
+            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
             raise e

@@ -232,7 +233,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment["async_client"]
+            model_client = deployment.get("async_client", None)
             self.total_calls[original_model_string] +=1
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
             self.success_calls[original_model_string] +=1
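The two hunks above share one pattern: the router now looks up a pre-initialized OpenAI client on the deployment dict and forwards it to litellm through the `client` kwarg, instead of letting every call construct a fresh client. A minimal standalone sketch of that pattern, assuming a deployment dict shaped like the entries `set_model_list` builds further down; the model, key, and message values here are illustrative placeholders, not from the commit:

# Sketch only: the deployment/key values below are placeholders.
import openai
import litellm

deployment = {
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
    "client": openai.OpenAI(api_key="sk-placeholder"),             # reused sync client
    "async_client": openai.AsyncOpenAI(api_key="sk-placeholder"),  # reused async client
}

data = dict(deployment["litellm_params"])
model_client = deployment.get("client", None)  # None -> litellm falls back to its own client
response = litellm.completion(
    **data,
    messages=[{"role": "user", "content": "hi"}],
    caching=False,
    client=model_client,
)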
@@ -301,8 +302,9 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
+            model_client = deployment.get("client", None)
             # call via litellm.embedding()
-            return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+            return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})

     async def aembedding(self,
                          model: str,

@@ -325,7 +327,9 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+            model_client = deployment.get("async_client", None)
+
+            return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})

     async def async_function_with_fallbacks(self, *args, **kwargs):
         """
@@ -801,23 +805,26 @@ class Router:
         self.model_list = model_list
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         for model in self.model_list:
-            if "azure" in model["litellm_params"]["model"]:
+            litellm_params = model.get("litellm_params", {})
+            model_name = litellm_params.get("model")
+
+            if "azure" in model_name:
                 model["async_client"] = openai.AsyncAzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
                 model["client"] = openai.AzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
-            elif model["litellm_params"]["model"] in litellm.open_ai_chat_completion_models:
+            elif model_name in litellm.open_ai_chat_completion_models:
                 model["async_client"] = openai.AsyncOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_ur = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
                 model["client"] = openai.OpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_url = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
             model_id = ""
             for key in model["litellm_params"]:
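For context, a rough end-to-end sketch (not part of the commit) of how these clients are used: set_model_list now reads each entry's litellm_params once, builds a sync and an async client for Azure/OpenAI models, and stores them on the entry; completion/acompletion/embedding/aembedding then pull them back out with deployment.get(...) as shown in the earlier hunks. The endpoint, key, and deployment names below are placeholders:

# Rough usage sketch; endpoint, key and deployment names are placeholders.
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",              # alias callers use
        "litellm_params": {
            "model": "azure/my-gpt-35-deployment",  # hypothetical Azure deployment
            "api_key": "my-azure-key",
            "api_base": "https://my-endpoint.openai.azure.com",
        },
    },
]

router = Router(model_list=model_list)
# set_model_list attached model["client"] / model["async_client"] to the entry above;
# completion() forwards the matching one to litellm as the `client` kwarg.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)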