Mirror of https://github.com/BerriAI/litellm.git
(feat) router: init stream, async stream, async, clients
This commit is contained in: parent 886b52d448 · commit 19646091fd
1 changed file with 85 additions and 12 deletions
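In short, the router now builds four OpenAI SDK clients per deployment up front (sync, async, and a streaming variant of each, where the streaming clients can carry their own timeout) and reuses them on every request instead of constructing clients per call. A minimal sketch of that pattern, separate from the actual diff below; the helper name make_clients is illustrative:

import openai

def make_clients(api_key, api_base, timeout, stream_timeout, max_retries):
    # Hypothetical helper: one client per (sync/async) x (stream/non-stream) combination.
    # Streaming clients get their own, typically shorter, timeout.
    return {
        "client": openai.OpenAI(api_key=api_key, base_url=api_base,
                                timeout=timeout, max_retries=max_retries),
        "async_client": openai.AsyncOpenAI(api_key=api_key, base_url=api_base,
                                           timeout=timeout, max_retries=max_retries),
        "stream_client": openai.OpenAI(api_key=api_key, base_url=api_base,
                                       timeout=stream_timeout, max_retries=max_retries),
        "stream_async_client": openai.AsyncOpenAI(api_key=api_key, base_url=api_base,
                                                  timeout=stream_timeout, max_retries=max_retries),
    }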
@@ -188,7 +188,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
             raise e
@@ -234,7 +234,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("async_client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
             self.total_calls[original_model_string] +=1
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
             self.success_calls[original_model_string] +=1
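For reference, a hedged usage sketch of the async path above: when stream=True is in kwargs, _get_client (added further down in this diff) hands back the pre-built stream_async_client rather than the regular async client. The deployment config, keys, and model names here are illustrative:

import asyncio
from litellm import Router

router = Router(model_list=[{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",            # illustrative deployment
        "api_key": "os.environ/AZURE_API_KEY",   # resolved via the os.environ/ convention
        "api_base": "os.environ/AZURE_API_BASE",
        "api_version": "2023-07-01-preview",
    },
}])

async def main():
    # stream=True routes this call through the deployment's stream_async_client
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())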
@@ -303,7 +303,7 @@ class Router:
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
-        model_client = deployment.get("client", None)
+        model_client = self._get_client(deployment=deployment, kwargs=kwargs)
         # call via litellm.embedding()
         return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
@@ -328,7 +328,7 @@ class Router:
             data["model"] = original_model_string[:index_of_model_id]
         else:
             data["model"] = original_model_string
-        model_client = deployment.get("async_client", None)
+        model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
 
         return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
@@ -857,19 +857,19 @@ class Router:
                 if api_version and api_version.startswith("os.environ/"):
                     api_version_env_name = api_version.replace("os.environ/", "")
                     api_version = litellm.get_secret(api_version_env_name)
 
-                timeout = litellm_params.get("timeout")
-                if timeout and timeout.startswith("os.environ/"):
+                timeout = litellm_params.pop("timeout", None)
+                if isinstance(timeout, str) and timeout.startswith("os.environ/"):
                     timeout_env_name = api_version.replace("os.environ/", "")
                     timeout = litellm.get_secret(timeout_env_name)
 
-                stream_timeout = litellm_params.get("stream_timeout")
-                if stream_timeout and stream_timeout.startswith("os.environ/"):
+                stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
+                if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
                     stream_timeout_env_name = api_version.replace("os.environ/", "")
                     stream_timeout = litellm.get_secret(stream_timeout_env_name)
 
-                max_retries = litellm_params.get("max_retries")
-                if max_retries and max_retries.startswith("os.environ/"):
+                max_retries = litellm_params.pop("max_retries", 2)
+                if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
                     max_retries_env_name = api_version.replace("os.environ/", "")
                     max_retries = litellm.get_secret(max_retries_env_name)
 
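The os.environ/ convention above lets a config value point at an environment variable instead of holding a literal; litellm.get_secret does the lookup. A minimal, self-contained sketch of that resolution pattern (the variable name AZURE_STREAM_TIMEOUT and its value are illustrative):

import os
import litellm

os.environ["AZURE_STREAM_TIMEOUT"] = "5"  # illustrative value

stream_timeout = "os.environ/AZURE_STREAM_TIMEOUT"
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
    # derive the env-var name from the value being resolved, then read it
    env_name = stream_timeout.replace("os.environ/", "")
    stream_timeout = litellm.get_secret(env_name)

print(stream_timeout)  # value read from the environment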
@@ -898,6 +898,22 @@ class Router:
                             timeout=timeout,
                             max_retries=max_retries
                         )
+
+                        # streaming clients can have diff timeouts
+                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
+                            api_key=api_key,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=stream_timeout,
+                            max_retries=max_retries
+                        )
+                        model["stream_client"] = openai.AzureOpenAI(
+                            api_key=api_key,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=stream_timeout,
+                            max_retries=max_retries
+                        )
                     else:
                         model["async_client"] = openai.AsyncAzureOpenAI(
                             api_key=api_key,
@@ -913,6 +929,23 @@ class Router:
                             timeout=timeout,
                             max_retries=max_retries
                         )
+                        # streaming clients should have diff timeouts
+                        model["stream_async_client"] = openai.AsyncAzureOpenAI(
+                            api_key=api_key,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=stream_timeout,
+                            max_retries=max_retries
+                        )
+
+                        model["stream_client"] = openai.AzureOpenAI(
+                            api_key=api_key,
+                            base_url=api_base,
+                            api_version=api_version,
+                            timeout=stream_timeout,
+                            max_retries=max_retries
+                        )
+
                 else:
                     self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
                     model["async_client"] = openai.AsyncOpenAI(
@@ -927,6 +960,23 @@ class Router:
                         timeout=timeout,
                         max_retries=max_retries
                     )
+
+                    # streaming clients should have diff timeouts
+                    model["stream_async_client"] = openai.AsyncOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
+
+                    # streaming clients should have diff timeouts
+                    model["stream_client"] = openai.OpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
+
             ############ End of initializing Clients for OpenAI/Azure ###################
             model_id = ""
             for key in model["litellm_params"]:
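The extra clients initialized above exist so that streaming and non-streaming requests can carry different timeouts while both reuse pooled connections. A hedged sketch of a deployment entry that sets both knobs; the stream_timeout key feeds the pop("stream_timeout", timeout) logic earlier in this diff, and all values are illustrative:

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "azure/chatgpt-v-2",
        "api_key": "os.environ/AZURE_API_KEY",
        "api_base": "os.environ/AZURE_API_BASE",
        "api_version": "2023-07-01-preview",
        "timeout": 30,        # used for model["client"] / model["async_client"]
        "stream_timeout": 5,  # used for model["stream_client"] / model["stream_async_client"]
        "max_retries": 2,
    },
}]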
@@ -947,6 +997,29 @@ class Router:
     def get_model_names(self):
         return self.model_names
 
+    def _get_client(self, deployment, kwargs, client_type=None):
+        """
+        Returns the appropriate client based on the given deployment, kwargs, and client_type.
+
+        Parameters:
+            deployment (dict): The deployment dictionary containing the clients.
+            kwargs (dict): The keyword arguments passed to the function.
+            client_type (str): The type of client to return.
+
+        Returns:
+            The appropriate client based on the given client_type and kwargs.
+        """
+        if client_type == "async":
+            if kwargs.get("stream") == True:
+                return deployment["stream_async_client"]
+            else:
+                return deployment["async_client"]
+        else:
+            if kwargs.get("stream") == True:
+                return deployment["stream_client"]
+            else:
+                return deployment["client"]
+
     def print_verbose(self, print_statement):
         if self.set_verbose or litellm.set_verbose:
             print(f"LiteLLM.Router: {print_statement}") # noqa
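Tying it together, a usage sketch under the assumption of the model_list from the earlier sketch: which pre-built client a request uses is decided per call by _get_client, based on the stream flag in kwargs.

from litellm import Router

router = Router(model_list=model_list)  # model_list as sketched above

# non-streaming call -> deployment["client"] (see _get_client above)
resp = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)

# streaming call -> deployment["stream_client"], built with stream_timeout
stream = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk)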