Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(feat) router: init stream, async stream, async, clients
This commit is contained in:
commit 19646091fd
parent 886b52d448
1 changed file with 85 additions and 12 deletions
@@ -188,7 +188,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
         except Exception as e:
             raise e
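With this change, client selection moves out of the call sites: each deployment is expected to carry up to four client objects, and the new _get_client helper (added at the bottom of this diff) picks one based on client_type and kwargs.get("stream"). A minimal standalone sketch of that selection, using placeholder strings instead of real openai client instances:

# Hypothetical deployment dict; only the four key names are taken from this
# commit's set-up code and _get_client. The values are placeholders.
deployment = {
    "client": "sync-client",                       # sync, non-streaming
    "stream_client": "sync-stream-client",         # sync, stream=True
    "async_client": "async-client",                # async, non-streaming
    "stream_async_client": "async-stream-client",  # async, stream=True
}

def pick_client(deployment, kwargs, client_type=None):
    # Mirrors the branching of Router._get_client added later in this diff.
    if client_type == "async":
        return deployment["stream_async_client"] if kwargs.get("stream") else deployment["async_client"]
    return deployment["stream_client"] if kwargs.get("stream") else deployment["client"]

assert pick_client(deployment, {"stream": True}, client_type="async") == "async-stream-client"
assert pick_client(deployment, {}) == "sync-client"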
@@ -234,7 +234,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("async_client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
             self.total_calls[original_model_string] +=1
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
             self.success_calls[original_model_string] +=1
@@ -303,7 +303,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs)
             # call via litellm.embedding()
             return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
@@ -328,7 +328,7 @@ class Router:
                 data["model"] = original_model_string[:index_of_model_id]
             else:
                 data["model"] = original_model_string
-            model_client = deployment.get("async_client", None)
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
 
             return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
 
@@ -857,19 +857,19 @@ class Router:
                 if api_version and api_version.startswith("os.environ/"):
                     api_version_env_name = api_version.replace("os.environ/", "")
                     api_version = litellm.get_secret(api_version_env_name)
 
-                timeout = litellm_params.get("timeout")
-                if timeout and timeout.startswith("os.environ/"):
+                timeout = litellm_params.pop("timeout", None)
+                if isinstance(timeout, str) and timeout.startswith("os.environ/"):
                     timeout_env_name = api_version.replace("os.environ/", "")
                     timeout = litellm.get_secret(timeout_env_name)
 
-                stream_timeout = litellm_params.get("stream_timeout")
-                if stream_timeout and stream_timeout.startswith("os.environ/"):
+                stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
+                if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
                     stream_timeout_env_name = api_version.replace("os.environ/", "")
                     stream_timeout = litellm.get_secret(stream_timeout_env_name)
 
-                max_retries = litellm_params.get("max_retries")
-                if max_retries and max_retries.startswith("os.environ/"):
+                max_retries = litellm_params.pop("max_retries", 2)
+                if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
                     max_retries_env_name = api_version.replace("os.environ/", "")
                     max_retries = litellm.get_secret(max_retries_env_name)
 
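The timeout, stream_timeout, and max_retries values popped here come from each deployment's litellm_params and may be given either as literals or as "os.environ/<VAR>" references resolved through litellm.get_secret(). An illustrative model_list entry (model names and values below are made up, not taken from this diff; only the key names match what the code above reads):

# Hypothetical deployment config for a Router.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "os.environ/AZURE_API_KEY",
            "api_base": "os.environ/AZURE_API_BASE",
            "api_version": "os.environ/AZURE_API_VERSION",
            "timeout": 300,        # used for the regular clients
            "stream_timeout": 60,  # used for the stream_* clients; defaults to timeout if unset
            "max_retries": 7,      # defaults to 2 if unset
        },
    }
]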
@@ -898,6 +898,22 @@ class Router:
                         timeout=timeout,
                         max_retries=max_retries
                     )
+
+                    # streaming clients can have diff timeouts
+                    model["stream_async_client"] = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
+                    model["stream_client"] = openai.AzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
                 else:
                     model["async_client"] = openai.AsyncAzureOpenAI(
                         api_key=api_key,
@@ -913,6 +929,23 @@ class Router:
                         timeout=timeout,
                         max_retries=max_retries
                     )
+                    # streaming clients should have diff timeouts
+                    model["stream_async_client"] = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
+
+                    model["stream_client"] = openai.AzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version,
+                        timeout=stream_timeout,
+                        max_retries=max_retries
+                    )
+
             else:
                 self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
                 model["async_client"] = openai.AsyncOpenAI(
@@ -927,6 +960,23 @@ class Router:
                     timeout=timeout,
                     max_retries=max_retries
                 )
+
+                # streaming clients should have diff timeouts
+                model["stream_async_client"] = openai.AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries
+                )
+
+                # streaming clients should have diff timeouts
+                model["stream_client"] = openai.OpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    timeout=stream_timeout,
+                    max_retries=max_retries
+                )
+
             ############ End of initializing Clients for OpenAI/Azure ###################
             model_id = ""
             for key in model["litellm_params"]:
@@ -947,6 +997,29 @@ class Router:
     def get_model_names(self):
         return self.model_names
 
+    def _get_client(self, deployment, kwargs, client_type=None):
+        """
+        Returns the appropriate client based on the given deployment, kwargs, and client_type.
+
+        Parameters:
+            deployment (dict): The deployment dictionary containing the clients.
+            kwargs (dict): The keyword arguments passed to the function.
+            client_type (str): The type of client to return.
+
+        Returns:
+            The appropriate client based on the given client_type and kwargs.
+        """
+        if client_type == "async":
+            if kwargs.get("stream") == True:
+                return deployment["stream_async_client"]
+            else:
+                return deployment["async_client"]
+        else:
+            if kwargs.get("stream") == True:
+                return deployment["stream_client"]
+            else:
+                return deployment["client"]
+
     def print_verbose(self, print_statement):
         if self.set_verbose or litellm.set_verbose:
             print(f"LiteLLM.Router: {print_statement}") # noqa
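End to end, a streaming request should now be served by the client that was built with stream_timeout, while non-streaming requests keep using the timeout-based client. A rough usage sketch (assuming a Router constructed from a model_list like the one sketched earlier; this is not part of the diff):

import asyncio
from litellm import Router

router = Router(model_list=model_list)  # model_list as sketched earlier

# Sync streaming call: _get_client(...) returns deployment["stream_client"].
for chunk in router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
):
    print(chunk)

# Async non-streaming call: _get_client(..., client_type="async") returns deployment["async_client"].
async def main():
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)

asyncio.run(main())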