(feat) router: init stream, async stream, async, clients

ishaan-jaff 2023-12-04 17:29:07 -08:00
parent 886b52d448
commit 19646091fd


@@ -188,7 +188,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
raise e
@@ -234,7 +234,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("async_client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[original_model_string] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[original_model_string] +=1
@@ -303,7 +303,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -328,7 +328,7 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
model_client = deployment.get("async_client", None)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
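All four call sites above replace the direct deployment.get(...) lookup with self._get_client(...), so a single deployment can serve streaming and non-streaming traffic through separately configured clients. A hedged usage sketch (model name, credentials, and timeout values are illustrative placeholders; assumes AZURE_API_KEY / AZURE_API_BASE are set in the environment):

import asyncio
import litellm

router = litellm.Router(model_list=[{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",               # placeholder Azure deployment
        "api_key": "os.environ/AZURE_API_KEY",
        "api_base": "os.environ/AZURE_API_BASE",
        "api_version": "2023-07-01-preview",
        "timeout": 30,        # used for "client" / "async_client"
        "stream_timeout": 5,  # used for "stream_client" / "stream_async_client"
        "max_retries": 2,
    },
}])

messages = [{"role": "user", "content": "hello"}]

# sync, non-streaming -> deployment["client"] (timeout=30)
router.completion(model="gpt-3.5-turbo", messages=messages)

# async, streaming -> deployment["stream_async_client"] (timeout=5)
async def stream_once():
    response = await router.acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
    async for chunk in response:
        print(chunk)

asyncio.run(stream_once())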
@@ -857,19 +857,19 @@ class Router:
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
timeout = litellm_params.get("timeout")
if timeout and timeout.startswith("os.environ/"):
timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
stream_timeout = litellm_params.get("stream_timeout")
if stream_timeout and stream_timeout.startswith("os.environ/"):
stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
max_retries = litellm_params.get("max_retries")
if max_retries and max_retries.startswith("os.environ/"):
max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
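The rewritten block pops timeout, stream_timeout, and max_retries out of litellm_params and only dereferences them as secrets when they are strings of the form os.environ/<NAME>. A minimal stand-alone mirror of that pattern, with os.environ.get standing in for litellm.get_secret and a made-up env var name:

import os

def resolve(litellm_params: dict, key: str, default=None):
    # pop so the raw value is not forwarded to the underlying completion call
    value = litellm_params.pop(key, default)
    # only strings like "os.environ/AZURE_STREAM_TIMEOUT" are treated as secret references
    if isinstance(value, str) and value.startswith("os.environ/"):
        env_name = value.replace("os.environ/", "")
        value = os.environ.get(env_name)  # the router calls litellm.get_secret here
    return value

os.environ["AZURE_STREAM_TIMEOUT"] = "5"  # illustrative placeholder
params = {"timeout": 30, "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT"}

timeout = resolve(params, "timeout")                         # 30, plain value left as-is
stream_timeout = resolve(params, "stream_timeout", timeout)  # "5", read from the environment
max_retries = resolve(params, "max_retries", 2)              # 2, the default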
@@ -898,6 +898,22 @@ class Router:
timeout=timeout,
max_retries=max_retries
)
# streaming clients can have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
else:
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
@@ -913,6 +929,23 @@ class Router:
timeout=timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
model["stream_client"] = openai.AzureOpenAI(
api_key=api_key,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
)
else:
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
model["async_client"] = openai.AsyncOpenAI(
@@ -927,6 +960,23 @@ class Router:
timeout=timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries
)
# streaming clients should have diff timeouts
model["stream_client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries
)
############ End of initializing Clients for OpenAI/Azure ###################
model_id = ""
for key in model["litellm_params"]:
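Taken together, the hunks above give every deployment a stream_client and stream_async_client built with stream_timeout, alongside the existing client / async_client pair (the Azure variants additionally pass api_version). A condensed sketch of the resulting shape for a plain OpenAI deployment, using placeholder credentials and timeouts:

import openai

# placeholders; in the router these values come from litellm_params / resolved secrets
api_key = "sk-placeholder"
api_base = "https://api.openai.com/v1"
timeout, stream_timeout, max_retries = 600, 5, 2

model = {
    "client": openai.OpenAI(api_key=api_key, base_url=api_base, timeout=timeout, max_retries=max_retries),
    "async_client": openai.AsyncOpenAI(api_key=api_key, base_url=api_base, timeout=timeout, max_retries=max_retries),
    # streaming clients share credentials but use the (typically shorter) stream_timeout
    "stream_client": openai.OpenAI(api_key=api_key, base_url=api_base, timeout=stream_timeout, max_retries=max_retries),
    "stream_async_client": openai.AsyncOpenAI(api_key=api_key, base_url=api_base, timeout=stream_timeout, max_retries=max_retries),
}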
@@ -947,6 +997,29 @@ class Router:
def get_model_names(self):
return self.model_names
def _get_client(self, deployment, kwargs, client_type=None):
    """
    Returns the appropriate client based on the given deployment, kwargs, and client_type.

    Parameters:
        deployment (dict): The deployment dictionary containing the clients.
        kwargs (dict): The keyword arguments passed to the function.
        client_type (str): The type of client to return.

    Returns:
        The appropriate client based on the given client_type and kwargs.
    """
    if client_type == "async":
        if kwargs.get("stream") == True:
            return deployment["stream_async_client"]
        else:
            return deployment["async_client"]
    else:
        if kwargs.get("stream") == True:
            return deployment["stream_client"]
        else:
            return deployment["client"]
def print_verbose(self, print_statement):
    if self.set_verbose or litellm.set_verbose:
        print(f"LiteLLM.Router: {print_statement}") # noqa