(feat) router: re-use the same client for high trafic

This commit is contained in:
ishaan-jaff 2023-11-28 15:42:57 -08:00
parent ad6a3fb8fe
commit 57d774f3ad

View file

@ -216,6 +216,7 @@ class Router:
**kwargs): **kwargs):
try: try:
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}") self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
original_model_string = None # set a default for this variable
deployment = self.get_available_deployment(model=model, messages=messages) deployment = self.get_available_deployment(model=model, messages=messages)
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
@ -231,8 +232,9 @@ class Router:
data["model"] = original_model_string[:index_of_model_id] data["model"] = original_model_string[:index_of_model_id]
else: else:
data["model"] = original_model_string data["model"] = original_model_string
model_client = deployment["client"]
self.total_calls[original_model_string] +=1 self.total_calls[original_model_string] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs}) response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "azure_client": model_client, **kwargs})
self.success_calls[original_model_string] +=1 self.success_calls[original_model_string] +=1
return response return response
except Exception as e: except Exception as e:
@ -803,6 +805,11 @@ class Router:
for key in model["litellm_params"]: for key in model["litellm_params"]:
model_id+= str(model["litellm_params"][key]) model_id+= str(model["litellm_params"][key])
model["litellm_params"]["model"] += "-ModelID-" + model_id model["litellm_params"]["model"] += "-ModelID-" + model_id
model["client"] = openai.AsyncAzureOpenAI(
api_key= model["litellm_params"]["api_key"],
azure_endpoint = model["litellm_params"]["api_base"]
)
self.model_names = [m["model_name"] for m in model_list] self.model_names = [m["model_name"] for m in model_list]
def get_model_names(self): def get_model_names(self):