diff --git a/litellm/__init__.py b/litellm/__init__.py
index ffaf1a2be..2c68fa9af 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -140,6 +140,7 @@ return_response_headers: bool = (
 enable_json_schema_validation: bool = False
 ##################
 logging: bool = True
+enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 2c3fcfa1b..420b011a4 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,17 +1,12 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "batch-gpt-4o-mini"
     litellm_params:
-      model: "gpt-3.5-turbo"
+      model: "azure/gpt-4o-mini"
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
+    model_info:
+      mode: batch

 litellm_settings:
-  max_internal_user_budget: 0.02 # amount in USD
-  internal_user_budget_duration: "1s" # reset every second
-
-general_settings:
-  master_key: sk-1234
-  alerting: ["slack"]
-  alerting_threshold: 0.0001 # (Seconds) set an artifically low threshold for testing alerting
-  alert_to_webhook_url: {
-    "spend_reports": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"]
-  }
-
+  enable_loadbalancing_on_batch_endpoints: true
+
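The config above turns on router-based load balancing for the files/batches endpoints. A minimal client-side sketch, assuming a locally running proxy started with this config and a valid virtual key ("sk-1234" below is a placeholder); the exact route prefix may differ per deployment, since the files endpoint added later in this diff is mounted under `/{provider}/v1/files`:

```python
# Minimal client-side sketch, assuming a proxy on http://0.0.0.0:4000 with the config
# above and a placeholder virtual key. The JSONL file references "batch-gpt-4o-mini",
# which matches the model_list entry, so the proxy hands the request to the Router
# instead of a single provider.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# Upload the batch input file; the proxy inspects the first JSONL line to decide
# whether the referenced model is a load-balanced router model.
file_obj = client.files.create(
    file=open("openai_batch_completions.jsonl", "rb"),
    purpose="batch",
)

# Create the batch against the uploaded file.
batch = client.batches.create(
    input_file_id=file_obj.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(batch.id, batch.status)
```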
diff --git a/litellm/proxy/openai_files_endpoints/files_endpoints.py b/litellm/proxy/openai_files_endpoints/files_endpoints.py
index f4400b682..fa82c2a2d 100644
--- a/litellm/proxy/openai_files_endpoints/files_endpoints.py
+++ b/litellm/proxy/openai_files_endpoints/files_endpoints.py
@@ -31,6 +31,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.batches.main import FileObject
 from litellm.proxy._types import *
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.router import Router

 router = APIRouter()

@@ -66,6 +67,41 @@ def get_files_provider_config(
     return None


+def get_first_json_object(file_content_bytes: bytes) -> Optional[dict]:
+    try:
+        # Decode the bytes to a string and split into lines
+        file_content = file_content_bytes.decode("utf-8")
+        first_line = file_content.splitlines()[0].strip()
+
+        # Parse the JSON object from the first line
+        json_object = json.loads(first_line)
+        return json_object
+    except (json.JSONDecodeError, UnicodeDecodeError) as e:
+        return None
+
+
+def get_model_from_json_obj(json_object: dict) -> Optional[str]:
+    body = json_object.get("body", {}) or {}
+    model = body.get("model")
+
+    return model
+
+
+def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
+    """
+    Returns True if the model is in the llm_router model names
+    """
+    if model is None or llm_router is None:
+        return False
+    model_names = llm_router.get_model_names()
+
+    is_in_list = False
+    if model in model_names:
+        is_in_list = True
+
+    return is_in_list
+
+
 @router.post(
     "/{provider}/v1/files",
     dependencies=[Depends(user_api_key_auth)],
@@ -109,6 +145,7 @@ async def create_file(
         add_litellm_data_to_request,
         general_settings,
         get_custom_headers,
+        llm_router,
         proxy_config,
         proxy_logging_obj,
         version,
@@ -138,18 +175,46 @@ async def create_file(
         # Prepare the file data according to FileTypes
         file_data = (file.filename, file_content, file.content_type)

+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            json_obj = get_first_json_object(file_content_bytes=file_content)
+            if json_obj:
+                router_model = get_model_from_json_obj(json_object=json_obj)
+                is_router_model = is_known_model(
+                    model=router_model, llm_router=llm_router
+                )
+
         _create_file_request = CreateFileRequest(file=file_data, **data)

-        # get configs for custom_llm_provider
-        llm_provider_config = get_files_provider_config(
-            custom_llm_provider=custom_llm_provider
-        )
+        if (
+            litellm.enable_loadbalancing_on_batch_endpoints is True
+            and is_router_model
+            and router_model is not None
+        ):
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )

-        # add llm_provider_config to data
-        _create_file_request.update(llm_provider_config)
+            response = await llm_router.acreate_file(
+                model=router_model, **_create_file_request
+            )
+        else:
+            # get configs for custom_llm_provider
+            llm_provider_config = get_files_provider_config(
+                custom_llm_provider=custom_llm_provider
+            )

-        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
-        response = await litellm.acreate_file(**_create_file_request)  # type: ignore
+            # add llm_provider_config to data
+            _create_file_request.update(llm_provider_config)
+
+            # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+            response = await litellm.acreate_file(**_create_file_request)  # type: ignore

         ### ALERTING ###
         asyncio.create_task(
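A quick usage sketch of the helpers above; the JSONL payload is illustrative only, and the import assumes the proxy extras are installed so the module loads cleanly:

```python
# Usage sketch of get_first_json_object / get_model_from_json_obj (defined above).
from litellm.proxy.openai_files_endpoints.files_endpoints import (
    get_first_json_object,
    get_model_from_json_obj,
)

file_bytes = (
    b'{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", '
    b'"body": {"model": "my-custom-name", "messages": [{"role": "user", "content": "hi"}]}}\n'
)

json_obj = get_first_json_object(file_content_bytes=file_bytes)  # parses only the first JSONL line
router_model = get_model_from_json_obj(json_object=json_obj or {})
print(router_model)  # -> "my-custom-name"

# is_known_model(model=router_model, llm_router=llm_router) then checks this name
# against llm_router.get_model_names() to decide whether to load-balance the upload.
```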
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index fbefa867b..dd6869c66 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
     router as team_callback_router,
 )
 from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
 from litellm.proxy.openai_files_endpoints.files_endpoints import (
     router as openai_files_router,
 )
@@ -4979,13 +4980,35 @@ async def create_batch(
             proxy_config=proxy_config,
         )

+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            router_model = data.get("model", None)
+            is_router_model = is_known_model(model=router_model, llm_router=llm_router)
+
         _create_batch_data = CreateBatchRequest(**data)

-        if provider is None:
-            provider = "openai"
-        response = await litellm.acreate_batch(
-            custom_llm_provider=provider, **_create_batch_data  # type: ignore
-        )
+        if (
+            litellm.enable_loadbalancing_on_batch_endpoints is True
+            and is_router_model
+            and router_model is not None
+        ):
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )
+
+            response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
+        else:
+            if provider is None:
+                provider = "openai"
+            response = await litellm.acreate_batch(
+                custom_llm_provider=provider, **_create_batch_data  # type: ignore
+            )

         ### ALERTING ###
         asyncio.create_task(
@@ -5017,7 +5040,7 @@ async def create_batch(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
                 str(e)
             )
@@ -5080,15 +5103,30 @@ async def retrieve_batch(
     global proxy_logging_obj
     data: Dict = {}
     try:
+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+
         _retrieve_batch_request = RetrieveBatchRequest(
             batch_id=batch_id,
         )

-        if provider is None:
-            provider = "openai"
-        response = await litellm.aretrieve_batch(
-            custom_llm_provider=provider, **_retrieve_batch_request  # type: ignore
-        )
+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )
+
+            response = await llm_router.aretrieve_batch(**_retrieve_batch_request)  # type: ignore
+        else:
+            if provider is None:
+                provider = "openai"
+            response = await litellm.aretrieve_batch(
+                custom_llm_provider=provider, **_retrieve_batch_request  # type: ignore
+            )

         ### ALERTING ###
         asyncio.create_task(
@@ -5120,7 +5158,7 @@ async def retrieve_batch(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
                 str(e)
             )
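With the flag enabled, batch retrieval through the proxy is fanned out across all configured deployments via `llm_router.aretrieve_batch`, so the caller does not need to know which deployment owns the batch. A short continuation of the earlier client sketch (the batch id is a placeholder):

```python
# Continuation of the earlier client sketch; "batch_abc123" is a placeholder id.
retrieved = client.batches.retrieve("batch_abc123")
print(retrieved.status, retrieved.output_file_id)
```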
diff --git a/litellm/router.py b/litellm/router.py
index 1a433858d..15fb4cb27 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -54,6 +54,10 @@ from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.batch_utils import (
+    _get_router_metadata_variable_name,
+    replace_model_in_jsonl,
+)
 from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
@@ -73,6 +77,12 @@ from litellm.types.llms.openai import (
     AssistantToolParam,
     AsyncCursorPage,
     Attachment,
+    Batch,
+    CreateFileRequest,
+    FileContentRequest,
+    FileObject,
+    FileTypes,
+    HttpxBinaryResponseContent,
     OpenAIMessage,
     Run,
     Thread,
@@ -103,6 +113,7 @@ from litellm.utils import (
     _is_region_eu,
     calculate_max_parallel_requests,
     create_proxy_transport_and_mounts,
+    get_llm_provider,
     get_utc_datetime,
 )

@@ -2228,6 +2239,373 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e

+    #### FILES API ####
+
+    async def acreate_file(
+        self,
+        model: str,
+        **kwargs,
+    ) -> FileObject:
+        try:
+            kwargs["model"] = model
+            kwargs["original_function"] = self._acreate_file
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
+    async def _acreate_file(
+        self,
+        model: str,
+        **kwargs,
+    ) -> FileObject:
+        try:
+            verbose_router_logger.debug(
+                f"Inside _acreate_file()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "files-api-fake-text"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
+            stripped_model, custom_llm_provider, _, _ = get_llm_provider(
+                model=data["model"]
+            )
+            kwargs["file"] = replace_model_in_jsonl(
+                file_content=kwargs["file"], new_model_name=stripped_model
+            )
+
+            response = litellm.acreate_file(
+                **{
+                    **data,
+                    "custom_llm_provider": custom_llm_provider,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": self.timeout,
+                    **kwargs,
+                }
+            )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response  # type: ignore
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response  # type: ignore
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response  # type: ignore
+        except Exception as e:
+            verbose_router_logger.exception(
+                f"litellm.acreate_file(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e
+
+    async def acreate_batch(
+        self,
+        model: str,
+        **kwargs,
+    ) -> Batch:
+        try:
+            kwargs["model"] = model
+            kwargs["original_function"] = self._acreate_batch
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
+    async def _acreate_batch(
+        self,
+        model: str,
+        **kwargs,
+    ) -> Batch:
+        try:
+            verbose_router_logger.debug(
+                f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "files-api-fake-text"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            metadata_variable_name = _get_router_metadata_variable_name(
+                function_name="_acreate_batch"
+            )
+
+            kwargs.setdefault(metadata_variable_name, {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == metadata_variable_name:
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
+            _, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])
+
+            response = litellm.acreate_batch(
+                **{
+                    **data,
+                    "custom_llm_provider": custom_llm_provider,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": self.timeout,
+                    **kwargs,
+                }
+            )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response  # type: ignore
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response  # type: ignore
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.acreate_batch(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response  # type: ignore
+        except Exception as e:
+            verbose_router_logger.exception(
+                f"litellm._acreate_batch(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e
+
+ """ + try: + + filtered_model_list = self.get_model_list() + if filtered_model_list is None: + raise Exception("Router not yet initialized.") + + receieved_exceptions = [] + + async def try_retrieve_batch(model_name): + try: + # Update kwargs with the current model name or any other model-specific adjustments + ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## + _, custom_llm_provider, _, _ = get_llm_provider( + model=model_name["litellm_params"]["model"] + ) + new_kwargs = copy.deepcopy(kwargs) + new_kwargs.pop("custom_llm_provider", None) + return await litellm.aretrieve_batch( + custom_llm_provider=custom_llm_provider, **new_kwargs + ) + except Exception as e: + receieved_exceptions.append(e) + return None + + # Check all models in parallel + results = await asyncio.gather( + *[try_retrieve_batch(model) for model in filtered_model_list], + return_exceptions=True, + ) + + # Check for successful responses and handle exceptions + for result in results: + if isinstance(result, Batch): + return result + + # If no valid Batch response was found, raise the first encountered exception + if receieved_exceptions: + raise receieved_exceptions[0] # Raising the first exception encountered + + # If no exceptions were encountered, raise a generic exception + raise Exception( + "Unable to find batch in any model. Received errors - {}".format( + receieved_exceptions + ) + ) + except Exception as e: + asyncio.create_task( + send_llm_exception_alert( + litellm_router_instance=self, + request_kwargs=kwargs, + error_traceback_str=traceback.format_exc(), + original_exception=e, + ) + ) + raise e + + async def alist_batches( + self, + model: str, + **kwargs, + ): + """ + Return all the batches across all deployments of a model group. + """ + + filtered_model_list = self.get_model_list(model_name=model) + if filtered_model_list is None: + raise Exception("Router not yet initialized.") + + async def try_retrieve_batch(model: DeploymentTypedDict): + try: + # Update kwargs with the current model name or any other model-specific adjustments + return await litellm.alist_batches( + **{**model["litellm_params"], **kwargs} + ) + except Exception as e: + return None + + # Check all models in parallel + results = await asyncio.gather( + *[try_retrieve_batch(model) for model in filtered_model_list] + ) + + final_results = { + "object": "list", + "data": [], + "first_id": None, + "last_id": None, + "has_more": False, + } + + for result in results: + if result is not None: + ## check batch id + if final_results["first_id"] is None: + final_results["first_id"] = result.first_id + final_results["last_id"] = result.last_id + final_results["data"].extend(result.data) # type: ignore + + ## check 'has_more' + if result.has_more is True: + final_results["has_more"] = True + + return final_results + #### ASSISTANTS API #### async def acreate_assistants( @@ -4132,9 +4510,18 @@ class Router: def get_model_names(self) -> List[str]: return self.model_names - def get_model_list(self): + def get_model_list( + self, model_name: Optional[str] = None + ) -> Optional[List[DeploymentTypedDict]]: if hasattr(self, "model_list"): - return self.model_list + if model_name is None: + return self.model_list + + returned_models: List[DeploymentTypedDict] = [] + for model in self.model_list: + if model["model_name"] == model_name: + returned_models.append(model) + return returned_models return None def get_model_access_groups(self): diff --git a/litellm/router_utils/batch_utils.py b/litellm/router_utils/batch_utils.py new file mode 100644 index 
diff --git a/litellm/router_utils/batch_utils.py b/litellm/router_utils/batch_utils.py
new file mode 100644
index 000000000..af080643f
--- /dev/null
+++ b/litellm/router_utils/batch_utils.py
@@ -0,0 +1,59 @@
+import io
+import json
+from typing import IO, Optional, Tuple, Union
+
+
+class InMemoryFile(io.BytesIO):
+    def __init__(self, content: bytes, name: str):
+        super().__init__(content)
+        self.name = name
+
+
+def replace_model_in_jsonl(
+    file_content: Union[bytes, IO, Tuple[str, bytes, str]], new_model_name: str
+) -> Optional[InMemoryFile]:
+    try:
+        # Decode the bytes to a string and split into lines
+        # If file_content is a file-like object, read the bytes
+        if hasattr(file_content, "read"):
+            file_content_bytes = file_content.read()  # type: ignore
+        elif isinstance(file_content, tuple):
+            file_content_bytes = file_content[1]
+        else:
+            file_content_bytes = file_content
+
+        # Decode the bytes to a string and split into lines
+        file_content_str = file_content_bytes.decode("utf-8")
+        lines = file_content_str.splitlines()
+        modified_lines = []
+        for line in lines:
+            # Parse each line as a JSON object
+            json_object = json.loads(line.strip())
+
+            # Replace the model name if it exists
+            if "body" in json_object:
+                json_object["body"]["model"] = new_model_name
+
+            # Convert the modified JSON object back to a string
+            modified_lines.append(json.dumps(json_object))
+
+        # Reassemble the modified lines and return as bytes
+        modified_file_content = "\n".join(modified_lines).encode("utf-8")
+        return InMemoryFile(modified_file_content, name="modified_file.jsonl")  # type: ignore
+
+    except (json.JSONDecodeError, UnicodeDecodeError, TypeError) as e:
+        return None
+
+
+def _get_router_metadata_variable_name(function_name) -> str:
+    """
+    Helper to return what the "metadata" field should be called in the request data
+
+    For all /thread or /assistant endpoints we need to call this "litellm_metadata"
+
+    For ALL other endpoints we call this "metadata"
+    """
+    if "batch" in function_name:
+        return "litellm_metadata"
+    else:
+        return "metadata"
diff --git a/litellm/tests/openai_batch_completions_router.jsonl b/litellm/tests/openai_batch_completions_router.jsonl
new file mode 100644
index 000000000..8a4c99ca8
--- /dev/null
+++ b/litellm/tests/openai_batch_completions_router.jsonl
@@ -0,0 +1,3 @@
+{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
+{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
+{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
\ No newline at end of file
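A usage sketch for `replace_model_in_jsonl` above: it rewrites `body.model` on every JSONL line so the uploaded file targets the concrete deployment the router selected (the JSONL line below is illustrative):

```python
# Usage sketch for replace_model_in_jsonl (defined above).
from litellm.router_utils.batch_utils import replace_model_in_jsonl

raw = (
    b'{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", '
    b'"body": {"model": "my-custom-name", "messages": [{"role": "user", "content": "hi"}]}}'
)

rewritten = replace_model_in_jsonl(file_content=raw, new_model_name="gpt-4o-mini")
if rewritten is not None:
    print(rewritten.name)                    # "modified_file.jsonl"
    print(rewritten.read().decode("utf-8"))  # body.model is now "gpt-4o-mini"
```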
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index a0a96fa67..1a8cb831e 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2394,3 +2394,83 @@ async def test_router_weighted_pick(sync_mode):
     else:
         raise Exception("invalid model id returned!")
     assert model_id_1_count > model_id_2_count
+
+
+@pytest.mark.parametrize("provider", ["azure"])
+@pytest.mark.asyncio
+async def test_router_batch_endpoints(provider):
+    """
+    1. Create File for Batch completion
+    2. Create Batch Request
+    3. Retrieve the specific batch
+    """
+    print("Testing async create batch")
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-custom-name",
+                "litellm_params": {
+                    "model": "azure/gpt-4o-mini",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                },
+            },
+        ]
+    )
+
+    file_name = "openai_batch_completions_router.jsonl"
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(_current_dir, file_name)
+    file_obj = await router.acreate_file(
+        model="my-custom-name",
+        file=open(file_path, "rb"),
+        purpose="batch",
+        custom_llm_provider=provider,
+    )
+    print("Response from creating file=", file_obj)
+
+    await asyncio.sleep(10)
+    batch_input_file_id = file_obj.id
+    assert (
+        batch_input_file_id is not None
+    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
+
+    create_batch_response = await router.acreate_batch(
+        model="my-custom-name",
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider=provider,
+        metadata={"key1": "value1", "key2": "value2"},
+    )
+
+    print("response from router.create_batch=", create_batch_response)
+
+    assert (
+        create_batch_response.id is not None
+    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
+    assert (
+        create_batch_response.endpoint == "/v1/chat/completions"
+        or create_batch_response.endpoint == "/chat/completions"
+    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
+    assert (
+        create_batch_response.input_file_id == batch_input_file_id
+    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
+
+    await asyncio.sleep(1)
+
+    retrieved_batch = await router.aretrieve_batch(
+        batch_id=create_batch_response.id,
+        custom_llm_provider=provider,
+    )
+    print("retrieved batch=", retrieved_batch)
+    # just assert that we retrieved a non None batch
+
+    assert retrieved_batch.id == create_batch_response.id
+
+    # list all batches
+    list_batches = await router.alist_batches(
+        model="my-custom-name", custom_llm_provider=provider, limit=2
+    )
+    print("list_batches=", list_batches)
diff --git a/litellm/utils.py b/litellm/utils.py
index a7aadbf3f..26bf993aa 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4645,6 +4645,8 @@ def get_llm_provider(
     For router -> Can also give the whole litellm param dict -> this function will extract the relevant details

     Raises Error - if unable to map model to a provider
+
+    Return model, custom_llm_provider, dynamic_api_key, api_base
     """
     try:
         ## IF LITELLM PARAMS GIVEN ##
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} \ No newline at end of file