Mirror of https://github.com/BerriAI/litellm.git
(fix) batch_completion fails with bedrock due to extraneous [max_workers] key (#6176)
* fix batch_completion
* fix import batch completion
* fix batch completion usage
Parent: 11f9df923a
Commit: b032e898c2
5 changed files with 269 additions and 244 deletions
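For context, the failure being fixed: `batch_completion` builds its per-request kwargs from `args = locals()`, which includes the pool-sizing `max_workers` parameter, so that key leaked into every underlying `litellm.completion` call and providers such as Bedrock rejected it as an unknown parameter. A minimal, hedged reproduction of the failing call path (the Bedrock model id and prompts are illustrative, not taken from the commit):

```python
import litellm

# Before this fix, each underlying completion() call received an extraneous
# max_workers kwarg copied from batch_completion's own signature, which
# Bedrock rejects. The model id below is illustrative.
responses = litellm.batch_completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[
        [{"role": "user", "content": "good morning"}],
        [{"role": "user", "content": "what is the capital of France?"}],
    ],
    max_tokens=50,
)
for r in responses:
    print(r)
```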
@@ -1041,6 +1041,7 @@ from .proxy.proxy_cli import run_server
 from .router import Router
 from .assistants.main import *
 from .batches.main import *
+from .batch_completion.main import *
 from .rerank_api.main import *
 from .realtime_api.main import _arealtime
 from .fine_tuning.main import *
litellm/batch_completion/Readme.md (new file, 11 lines)
@@ -0,0 +1,11 @@
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`

Doc: https://docs.litellm.ai/docs/completion/batching


LiteLLM Python SDK allows you to:
1. `litellm.batch_completion` Batch litellm.completion function for a given model.
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
as soon as one of the models responds.
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
from all models that respond.
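A hedged usage sketch of the three helpers the Readme lists; model names and prompts are illustrative and not part of this commit:

```python
import litellm

# 1. batch_completion: one model, many message lists, executed in parallel threads.
results = litellm.batch_completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[
        [{"role": "user", "content": "write a haiku about clouds"}],
        [{"role": "user", "content": "write a haiku about rivers"}],
    ],
)

# 2. batch_completion_models: fan one request out to several models and
#    return the first non-None response.
first = litellm.batch_completion_models(
    models=["gpt-3.5-turbo", "claude-3-haiku-20240307"],  # illustrative
    messages=[{"role": "user", "content": "hello"}],
)

# 3. batch_completion_models_all_responses: same fan-out, but collect a list
#    of responses from every model that answers.
all_responses = litellm.batch_completion_models_all_responses(
    models=["gpt-3.5-turbo", "claude-3-haiku-20240307"],  # illustrative
    messages=[{"role": "user", "content": "hello"}],
)
```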
litellm/batch_completion/main.py (new file, 253 lines)
@@ -0,0 +1,253 @@
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import List, Optional

import litellm
from litellm import completion
from litellm._logging import print_verbose
from litellm.utils import get_optional_params

from ..llms import vllm


def batch_completion(
    model: str,
    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
    messages: List = [],
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    stream: Optional[bool] = None,
    stop=None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    deployment_id=None,
    request_timeout: Optional[int] = None,
    timeout: Optional[int] = 600,
    max_workers: Optional[int] = 100,
    # Optional liteLLM function params
    **kwargs,
):
    """
    Batch litellm.completion function for a given model.

    Args:
        model (str): The model to use for generating completions.
        messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
        functions (List, optional): List of functions to use as input for generating completions. Defaults to [].
        function_call (str, optional): The function call to use as input for generating completions. Defaults to "".
        temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
        top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
        n (int, optional): The number of completions to generate. Defaults to None.
        stream (bool, optional): Whether to stream completions or not. Defaults to None.
        stop (optional): The stop parameter for generating completions. Defaults to None.
        max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
        presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
        frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
        logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}.
        user (str, optional): The user string for generating completions. Defaults to "".
        deployment_id (optional): The deployment ID for generating completions. Defaults to None.
        request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
        max_workers (int,optional): The maximum number of threads to use for parallel processing.

    Returns:
        list: A list of completion results.
    """
    args = locals()

    batch_messages = messages
    completions = []
    model = model
    custom_llm_provider = None
    if model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]
    if custom_llm_provider == "vllm":
        optional_params = get_optional_params(
            functions=functions,
            function_call=function_call,
            temperature=temperature,
            top_p=top_p,
            n=n,
            stream=stream or False,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            # params to identify the model
            model=model,
            custom_llm_provider=custom_llm_provider,
        )
        results = vllm.batch_completions(
            model=model,
            messages=batch_messages,
            custom_prompt_dict=litellm.custom_prompt_dict,
            optional_params=optional_params,
        )
    # all non VLLM models for batch completion models
    else:

        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i : i + n]

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for sub_batch in chunks(batch_messages, 100):
                for message_list in sub_batch:
                    kwargs_modified = args.copy()
                    kwargs_modified.pop("max_workers")
                    kwargs_modified["messages"] = message_list
                    original_kwargs = {}
                    if "kwargs" in kwargs_modified:
                        original_kwargs = kwargs_modified.pop("kwargs")
                    future = executor.submit(
                        completion, **kwargs_modified, **original_kwargs
                    )
                    completions.append(future)

        # Retrieve the results from the futures
        # results = [future.result() for future in completions]
        # return exceptions if any
        results = []
        for future in completions:
            try:
                results.append(future.result())
            except Exception as exc:
                results.append(exc)

    return results


# send one request to multiple models
# return as soon as one of the llms responds
def batch_completion_models(*args, **kwargs):
    """
    Send a request to multiple language models concurrently and return the response
    as soon as one of the models responds.

    Args:
        *args: Variable-length positional arguments passed to the completion function.
        **kwargs: Additional keyword arguments:
            - models (str or list of str): The language models to send requests to.
            - Other keyword arguments to be passed to the completion function.

    Returns:
        str or None: The response from one of the language models, or None if no response is received.

    Note:
        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
        It sends requests concurrently and returns the response from the first model that responds.
    """
    import concurrent

    if "model" in kwargs:
        kwargs.pop("model")
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
        futures = {}
        with ThreadPoolExecutor(max_workers=len(models)) as executor:
            for model in models:
                futures[model] = executor.submit(
                    completion, *args, model=model, **kwargs
                )

            for model, future in sorted(
                futures.items(), key=lambda x: models.index(x[0])
            ):
                if future.result() is not None:
                    return future.result()
    elif "deployments" in kwargs:
        deployments = kwargs["deployments"]
        kwargs.pop("deployments")
        kwargs.pop("model_list")
        nested_kwargs = kwargs.pop("kwargs", {})
        futures = {}
        with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
            for deployment in deployments:
                for key in kwargs.keys():
                    if (
                        key not in deployment
                    ):  # don't override deployment values e.g. model name, api base, etc.
                        deployment[key] = kwargs[key]
                kwargs = {**deployment, **nested_kwargs}
                futures[deployment["model"]] = executor.submit(completion, **kwargs)

            while futures:
                # wait for the first returned future
                print_verbose("\n\n waiting for next result\n\n")
                done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
                print_verbose(f"done list\n{done}")
                for future in done:
                    try:
                        result = future.result()
                        return result
                    except Exception:
                        # if model 1 fails, continue with response from model 2, model3
                        print_verbose(
                            "\n\ngot an exception, ignoring, removing from futures"
                        )
                        print_verbose(futures)
                        new_futures = {}
                        for key, value in futures.items():
                            if future == value:
                                print_verbose(f"removing key{key}")
                                continue
                            else:
                                new_futures[key] = value
                        futures = new_futures
                        print_verbose(f"new futures{futures}")
                        continue

                print_verbose("\n\ndone looping through futures\n\n")
                print_verbose(futures)

    return None  # If no response is received from any model


def batch_completion_models_all_responses(*args, **kwargs):
    """
    Send a request to multiple language models concurrently and return a list of responses
    from all models that respond.

    Args:
        *args: Variable-length positional arguments passed to the completion function.
        **kwargs: Additional keyword arguments:
            - models (str or list of str): The language models to send requests to.
            - Other keyword arguments to be passed to the completion function.

    Returns:
        list: A list of responses from the language models that responded.

    Note:
        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
        It sends requests concurrently and collects responses from all models that respond.
    """
    import concurrent.futures

    # ANSI escape codes for colored output

    if "model" in kwargs:
        kwargs.pop("model")
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
    else:
        raise Exception("'models' param not in kwargs")

    responses = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
        for idx, model in enumerate(models):
            future = executor.submit(completion, *args, model=model, **kwargs)
            if future.result() is not None:
                responses.append(future.result())

    return responses
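The one-line `kwargs_modified.pop("max_workers")` above is the substance of the fix: `args = locals()` captures every parameter of `batch_completion`, including `max_workers`, and without the pop that pool-only key was forwarded to `completion()`. A standalone sketch of the pattern, not litellm code (`fake_completion` is a hypothetical stand-in for a strict provider):

```python
from concurrent.futures import ThreadPoolExecutor


def fake_completion(model, messages, **provider_kwargs):
    # Stand-in for litellm.completion: a strict provider rejects unknown keys.
    if "max_workers" in provider_kwargs:
        raise ValueError("unexpected keyword argument 'max_workers'")
    return {"model": model, "echo": messages[-1]["content"]}


def batch(model, messages_batch, max_workers=10, **kwargs):
    args = locals()  # includes max_workers, just like batch_completion
    futures = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for messages in messages_batch:
            call_kwargs = args.copy()
            call_kwargs.pop("max_workers")    # the fix: drop pool-only params
            call_kwargs.pop("messages_batch")
            call_kwargs.update(call_kwargs.pop("kwargs"))
            call_kwargs["messages"] = messages
            futures.append(executor.submit(fake_completion, **call_kwargs))
    return [f.result() for f in futures]


print(batch("demo-model", [[{"role": "user", "content": "hi"}]]))
```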
litellm/main.py (245 changed lines)
@@ -837,7 +837,7 @@ def completion(  # type: ignore
             deployments = [
                 m["litellm_params"] for m in model_list if m["model_name"] == model
             ]
-            return batch_completion_models(deployments=deployments, **args)
+            return litellm.batch_completion_models(deployments=deployments, **args)
     if litellm.model_alias_map and model in litellm.model_alias_map:
         model = litellm.model_alias_map[
             model
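The line changed above sits on the path `completion()` takes when it is given a `model_list` containing several deployments registered under the requested model name; it now fans out through the relocated `litellm.batch_completion_models`. A hedged illustration of that call shape (deployment values are illustrative):

```python
import litellm

# Two deployments registered under one model name; values are illustrative.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_base": "https://example.openai.azure.com",
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo"},
    },
]

# completion() collects the matching litellm_params into `deployments` and
# returns the first deployment that responds, via
# litellm.batch_completion_models(deployments=..., **args).
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    model_list=model_list,
)
```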
@@ -3016,249 +3016,6 @@ async def acompletion_with_retries(*args, **kwargs):
     return await retryer(original_function, *args, **kwargs)


-[243 lines removed: the in-file definitions of batch_completion,
- batch_completion_models, and batch_completion_models_all_responses,
- identical to the relocated litellm/batch_completion/main.py shown above,
- except that the old batch_completion did not pop "max_workers" from the
- kwargs it forwarded to completion().]


 ### EMBEDDING ENDPOINTS ####################
 @client
 async def aembedding(*args, **kwargs) -> EmbeddingResponse:
@@ -37,6 +37,9 @@ def test_batch_completions():
         print(result)
         print(len(result))
         assert len(result) == 3
+
+        for response in result:
+            assert response.choices[0].message.content is not None
     except Timeout as e:
         print(f"IN TIMEOUT")
         pass