forked from phoenix/litellm-mirror
(fix) router: passing client
This commit is contained in:
parent
4d06c296e3
commit
282b9a37e5
2 changed files with 25 additions and 16 deletions
@@ -1765,6 +1765,7 @@ def embedding(
             timeout=timeout,
             model_response=EmbeddingResponse(),
             optional_params=optional_params,
+            client=optional_params.pop("client", None)
         )
     elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
         api_base = (
@@ -1798,7 +1799,8 @@ def embedding(
             logging_obj=logging,
             timeout=timeout,
             model_response=EmbeddingResponse(),
-            optional_params=optional_params
+            optional_params=optional_params,
+            client=optional_params.pop("client", None)
         )
     elif model in litellm.cohere_embedding_models:
         cohere_key = (
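For context, a minimal usage sketch (not part of the diff) of what the new client=optional_params.pop("client", None) argument enables: a caller can build one Azure/OpenAI client up front and hand it to litellm.embedding, which now forwards that client to the provider handler instead of constructing a fresh one per request. The key, endpoint, and deployment name below are placeholders, and the sketch assumes extra kwargs such as client are collected into optional_params, as the router call path in this commit suggests.

    # Hypothetical usage sketch; key, endpoint, and deployment name are placeholders.
    import openai
    import litellm

    azure_client = openai.AzureOpenAI(
        api_key="my-azure-key",                                  # placeholder
        azure_endpoint="https://my-endpoint.openai.azure.com",   # placeholder
        api_version="2023-07-01-preview",                        # placeholder
    )

    # The client kwarg travels through litellm.embedding's optional_params and is
    # popped back out via optional_params.pop("client", None), per the hunks above.
    response = litellm.embedding(
        model="azure/my-embedding-deployment",   # placeholder deployment name
        input=["hello world"],
        client=azure_client,
    )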
@@ -188,7 +188,8 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            model_client = deployment.get("client", None)
+            return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
             raise e

@@ -232,7 +233,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment["async_client"]
+            model_client = deployment.get("async_client", None)
             self.total_calls[original_model_string] +=1
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
             self.success_calls[original_model_string] +=1
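The switch from deployment["async_client"] to deployment.get("async_client", None) matters because, per the client-initialization hunk further down, only Azure and OpenAI chat deployments get client/async_client attached. A rough illustration, with a made-up deployment dict:

    # Illustrative only: a non-OpenAI deployment never receives the cached clients.
    deployment = {"model_name": "claude-2", "litellm_params": {"model": "claude-2"}}

    model_client = deployment.get("async_client", None)   # -> None, no KeyError
    # deployment["async_client"] would raise KeyError here; with None, litellm
    # falls back to creating its own client for that provider.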
@@ -301,8 +302,9 @@ class Router:
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
+        model_client = deployment.get("client", None)
         # call via litellm.embedding()
-        return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+        return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})

     async def aembedding(self,
                          model: str,
@@ -325,7 +327,9 @@ class Router:
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
-        return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+        model_client = deployment.get("async_client", None)
+
+        return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})

     async def async_function_with_fallbacks(self, *args, **kwargs):
         """
@@ -801,23 +805,26 @@ class Router:
         self.model_list = model_list
         # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
         for model in self.model_list:
-            if "azure" in model["litellm_params"]["model"]:
+            litellm_params = model.get("litellm_params", {})
+            model_name = litellm_params.get("model")
+
+            if "azure" in model_name:
                 model["async_client"] = openai.AsyncAzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
                 model["client"] = openai.AzureOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    azure_endpoint = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    azure_endpoint=litellm_params.get("api_base")
                 )
-            elif model["litellm_params"]["model"] in litellm.open_ai_chat_completion_models:
+            elif model_name in litellm.open_ai_chat_completion_models:
                 model["async_client"] = openai.AsyncOpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_ur = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
                 model["client"] = openai.OpenAI(
-                    api_key= model["litellm_params"]["api_key"],
-                    base_url = model["litellm_params"]["api_base"]
+                    api_key=litellm_params.get("api_key"),
+                    base_url=litellm_params.get("api_base")
                 )
             model_id = ""
             for key in model["litellm_params"]:
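Putting the pieces together, a hedged end-to-end sketch of the behavior this commit wires up: Router builds one sync and one async openai client per Azure/OpenAI deployment when the model list is set, then reuses those cached clients on every completion/embedding call via the client kwarg. Keys, endpoints, and deployment names below are placeholders.

    # Hypothetical end-to-end sketch; credentials and deployment names are placeholders.
    from litellm import Router

    model_list = [
        {
            "model_name": "gpt-3.5-turbo",            # routing alias
            "litellm_params": {
                "model": "azure/chatgpt-v-2",         # placeholder Azure deployment
                "api_key": "my-azure-key",            # placeholder
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
            },
        },
    ]

    router = Router(model_list=model_list)
    # set_model_list attached model["client"] / model["async_client"] above; each
    # router.completion() / router.embedding() call now forwards the cached client
    # to litellm via the "client" kwarg instead of rebuilding one per request.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hey"}],
    )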