diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 0a5eb19b6..b3acd2e34 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -22,7 +22,8 @@
         "ms-python.python",
         "ms-python.vscode-pylance",
         "GitHub.copilot",
-        "GitHub.copilot-chat"
+        "GitHub.copilot-chat",
+        "ms-python.autopep8"
       ]
     }
   },
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d429bc6b8..a33473b72 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-    # - id: mypy
-    #   name: mypy
-    #   entry: python3 -m mypy --ignore-missing-imports
-    #   language: system
-    #   types: [python]
-    #   files: ^litellm/
+    - id: mypy
+      name: mypy
+      entry: python3 -m mypy --ignore-missing-imports
+      language: system
+      types: [python]
+      files: ^litellm/
     - id: isort
       name: isort
      entry: isort
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 0b5e9a4aa..e04230e7e 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -208,6 +208,14 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["text"]
+            elif (
+                kwargs.get("call_type") is not None
+                and kwargs.get("call_type") == "pass_through_endpoint"
+                and response_obj is not None
+                and isinstance(response_obj, dict)
+            ):
+                input = prompt
+                output = response_obj.get("response", "")
             print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
             trace_id = None
             generation_id = None
diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py
index c440be5f1..d915100b0 100644
--- a/litellm/integrations/s3.py
+++ b/litellm/integrations/s3.py
@@ -101,12 +101,6 @@ class S3Logger:
             metadata = (
                 litellm_params.get("metadata", {}) or {}
             )  # if litellm_params['metadata'] == None
-            messages = kwargs.get("messages")
-            optional_params = kwargs.get("optional_params", {})
-            call_type = kwargs.get("call_type", "litellm.completion")
-            cache_hit = kwargs.get("cache_hit", False)
-            usage = response_obj["usage"]
-            id = response_obj.get("id", str(uuid.uuid4()))

             # Clean Metadata before logging - never log raw metadata
             # the raw metadata can contain circular references which leads to infinite recursion
@@ -171,5 +165,5 @@ class S3Logger:
             print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
             return response
         except Exception as e:
-            verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
+            verbose_logger.exception(f"s3 Layer Error - {str(e)}")
             pass
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 2ea3f23d3..7af0a1cad 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -41,6 +41,7 @@ from litellm.types.utils import (
     StandardLoggingMetadata,
     StandardLoggingModelInformation,
     StandardLoggingPayload,
+    StandardPassThroughResponseObject,
     TextCompletionResponse,
     TranscriptionResponse,
 )
@@ -534,7 +535,9 @@ class Logging:
         """
         ## RESPONSE COST ##
         custom_pricing = use_custom_pricing_for_model(
-            litellm_params=self.litellm_params
+            litellm_params=(
+                self.litellm_params if hasattr(self, "litellm_params") else None
+            )
         )

         if cache_hit is None:
@@ -611,6 +614,17 @@ class Logging:
                 ] = result._hidden_params

                 ## STANDARDIZED LOGGING PAYLOAD
+                self.model_call_details["standard_logging_object"] = (
+                    get_standard_logging_object_payload(
+                        kwargs=self.model_call_details,
+                        init_response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        logging_obj=self,
+                    )
+                )
+            elif isinstance(result, dict):  # pass-through endpoints
+                ## STANDARDIZED LOGGING PAYLOAD
                 self.model_call_details["standard_logging_object"] = (
                     get_standard_logging_object_payload(
                         kwargs=self.model_call_details,
@@ -2271,6 +2285,8 @@ def get_standard_logging_object_payload(
     elif isinstance(init_response_obj, BaseModel):
         response_obj = init_response_obj.model_dump()
         hidden_params = getattr(init_response_obj, "_hidden_params", None)
+    elif isinstance(init_response_obj, dict):
+        response_obj = init_response_obj
     else:
         response_obj = {}
     # standardize this function to be used across, s3, dynamoDB, langfuse logging
diff --git a/litellm/llms/anthropic/chat.py b/litellm/llms/anthropic/chat.py
index dd7ab58c1..18e530bb7 100644
--- a/litellm/llms/anthropic/chat.py
+++ b/litellm/llms/anthropic/chat.py
@@ -674,7 +674,7 @@ async def make_call(
     timeout: Optional[Union[float, httpx.Timeout]],
 ):
     if client is None:
-        client = _get_async_httpx_client()  # Create a new client if none provided
+        client = litellm.module_level_aclient

     try:
         response = await client.post(
@@ -690,11 +690,6 @@ async def make_call(
             raise e
         raise AnthropicError(status_code=500, message=str(e))

-    if response.status_code != 200:
-        raise AnthropicError(
-            status_code=response.status_code, message=await response.aread()
-        )
-
     completion_stream = ModelResponseIterator(
         streaming_response=response.aiter_lines(), sync_stream=False
     )
@@ -721,7 +716,7 @@ def make_sync_call(
     timeout: Optional[Union[float, httpx.Timeout]],
 ):
     if client is None:
-        client = HTTPHandler()  # Create a new client if none provided
+        client = litellm.module_level_client  # re-use a module level client

     try:
         response = client.post(
@@ -869,6 +864,7 @@ class AnthropicChatCompletion(BaseLLM):
         model_response: ModelResponse,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
+        client: Optional[AsyncHTTPHandler],
         encoding,
         api_key,
         logging_obj,
@@ -882,19 +878,18 @@ class AnthropicChatCompletion(BaseLLM):
     ):
         data["stream"] = True

+        completion_stream = await make_call(
+            client=client,
+            api_base=api_base,
+            headers=headers,
+            data=json.dumps(data),
+            model=model,
+            messages=messages,
+            logging_obj=logging_obj,
+            timeout=timeout,
+        )
         streamwrapper = CustomStreamWrapper(
-            completion_stream=None,
-            make_call=partial(
-                make_call,
-                client=None,
-                api_base=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                model=model,
-                messages=messages,
-                logging_obj=logging_obj,
-                timeout=timeout,
-            ),
+            completion_stream=completion_stream,
             model=model,
             custom_llm_provider="anthropic",
             logging_obj=logging_obj,
@@ -1080,6 +1075,11 @@ class AnthropicChatCompletion(BaseLLM):
                 logger_fn=logger_fn,
                 headers=headers,
                 timeout=timeout,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
             )
         else:
             return self.acompletion_function(
@@ -1105,33 +1105,32 @@ class AnthropicChatCompletion(BaseLLM):
             )
         else:
             ## COMPLETION CALL
-            if client is None or not isinstance(client, HTTPHandler):
-                client = HTTPHandler(timeout=timeout)  # type: ignore
-            else:
-                client = client
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
+                completion_stream = make_sync_call(
+                    client=client,
+                    api_base=api_base,
+                    headers=headers,  # type: ignore
+                    data=json.dumps(data),
+                    model=model,
+                    messages=messages,
+                    logging_obj=logging_obj,
+                    timeout=timeout,
+                )
                return CustomStreamWrapper(
-                    completion_stream=None,
-                    make_call=partial(
-                        make_sync_call,
-                        client=None,
-                        api_base=api_base,
-                        headers=headers,  # type: ignore
-                        data=json.dumps(data),
-                        model=model,
-                        messages=messages,
-                        logging_obj=logging_obj,
-                        timeout=timeout,
-                    ),
+                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="anthropic",
                    logging_obj=logging_obj,
                )
            else:
+                if client is None or not isinstance(client, HTTPHandler):
+                    client = HTTPHandler(timeout=timeout)  # type: ignore
+                else:
+                    client = client
                response = client.post(
                    api_base, headers=headers, data=json.dumps(data), timeout=timeout
                )
diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py
index e6a1319b0..a7a6c173c 100644
--- a/litellm/llms/bedrock/embed/embedding.py
+++ b/litellm/llms/bedrock/embed/embedding.py
@@ -110,7 +110,7 @@ class BedrockEmbedding(BaseAWSLLM):
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
-            raise BedrockError(status_code=error_code, message=response.text)
+            raise BedrockError(status_code=error_code, message=err.response.text)
         except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")

diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 939a0605a..2f07ee2f7 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -37,16 +37,10 @@ class AsyncHTTPHandler:

         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv(
-            "SSL_VERIFY",
-            litellm.ssl_verify
-        )
+        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
-        cert = os.getenv(
-            "SSL_CERTIFICATE",
-            litellm.ssl_certificate
-        )
+        cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

         if timeout is None:
             timeout = _DEFAULT_TIMEOUT
@@ -277,16 +271,10 @@ class HTTPHandler:

         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv(
-            "SSL_VERIFY",
-            litellm.ssl_verify
-        )
+        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
-        cert = os.getenv(
-            "SSL_CERTIFICATE",
-            litellm.ssl_certificate
-        )
+        cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

         if client is None:
             # Create a client with a connection pool
@@ -334,6 +322,7 @@ class HTTPHandler:
                 "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
             )
             response = self.client.send(req, stream=stream)
+            response.raise_for_status()
             return response
         except httpx.TimeoutException:
             raise litellm.Timeout(
@@ -341,6 +330,13 @@ class HTTPHandler:
                 model="default-model-name",
                 llm_provider="litellm-httpx-handler",
             )
+        except httpx.HTTPStatusError as e:
+            setattr(e, "status_code", e.response.status_code)
+            if stream is True:
+                setattr(e, "message", e.response.read())
+            else:
+                setattr(e, "message", e.response.text)
+            raise e
         except Exception as e:
             raise e

@@ -375,7 +371,6 @@ class HTTPHandler:
         except Exception as e:
             raise e

-
     def __del__(self) -> None:
         try:
             self.close()
@@ -437,4 +432,4 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
         _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

     litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
-    return _new_client
\ No newline at end of file
+    return _new_client
diff --git a/litellm/main.py b/litellm/main.py
index bb2c1c47f..cb3555619 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -534,6 +534,15 @@ def mock_completion(
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
+    elif isinstance(mock_response, str) and mock_response.startswith(
+        "Exception: mock_streaming_error"
+    ):
+        mock_response = litellm.MockException(
+            message="This is a mock error raised mid-stream",
+            llm_provider="anthropic",
+            model=model,
+            status_code=529,
+        )
     time_delay = kwargs.get("mock_delay", None)
     if time_delay is not None:
         time.sleep(time_delay)
@@ -561,6 +570,8 @@ def mock_completion(
             custom_llm_provider="openai",
             logging_obj=logging,
         )
+    if isinstance(mock_response, litellm.MockException):
+        raise mock_response
     if n is None:
         model_response.choices[0].message.content = mock_response  # type: ignore
     else:
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index b58725d5f..487e187a3 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -3408,6 +3408,15 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.titan-text-premier-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 42000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015,
+        "litellm_provider": "bedrock",
+        "mode": "chat"
+    },
     "amazon.titan-embed-text-v1": {
         "max_tokens": 8192,
         "max_input_tokens": 8192,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 51a995285..335e93447 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,7 +1,16 @@
 model_list:
-  - model_name: "*"
+  - model_name: "anthropic/claude-3-5-sonnet-20240620"
+    litellm_params:
+      model: anthropic/claude-3-5-sonnet-20240620
+      # api_base: http://0.0.0.0:9000
+  - model_name: gpt-3.5-turbo
     litellm_params:
       model: openai/*

 litellm_settings:
-  default_max_internal_user_budget: 2
\ No newline at end of file
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: litellm-logs  # AWS Bucket Name for S3
+    s3_region_name: us-west-2  # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID  # use os.environ/ to pass environment variables. This is the AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY  # AWS Secret Access Key for S3
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 67acf71e5..9b559adae 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -872,6 +872,17 @@ class TeamMemberDeleteRequest(LiteLLMBase):
         return values


+class TeamMemberUpdateRequest(TeamMemberDeleteRequest):
+    max_budget_in_team: float
+
+
+class TeamMemberUpdateResponse(LiteLLMBase):
+    team_id: str
+    user_id: str
+    user_email: Optional[str] = None
+    max_budget_in_team: float
+
+
 class UpdateTeamRequest(LiteLLMBase):
     """
     UpdateTeamRequest, used by /team/update when you need to update a team
@@ -1854,3 +1865,10 @@ class LiteLLM_TeamMembership(LiteLLMBase):
 class TeamAddMemberResponse(LiteLLM_TeamTable):
     updated_users: List[LiteLLM_UserTable]
     updated_team_memberships: List[LiteLLM_TeamMembership]
+
+
+class TeamInfoResponseObject(TypedDict):
+    team_id: str
+    team_info: TeamBase
+    keys: List
+    team_memberships: List[LiteLLM_TeamMembership]
diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py
index 500036bcc..ff13182f7 100644
--- a/litellm/proxy/management_endpoints/team_endpoints.py
+++ b/litellm/proxy/management_endpoints/team_endpoints.py
@@ -28,8 +28,11 @@ from litellm.proxy._types import (
     ProxyErrorTypes,
     ProxyException,
     TeamAddMemberResponse,
+    TeamInfoResponseObject,
     TeamMemberAddRequest,
     TeamMemberDeleteRequest,
+    TeamMemberUpdateRequest,
+    TeamMemberUpdateResponse,
     UpdateTeamRequest,
     UserAPIKeyAuth,
 )
@@ -750,6 +753,131 @@ async def team_member_delete(
     return existing_team_row


+@router.post(
+    "/team/member_update",
+    tags=["team management"],
+    dependencies=[Depends(user_api_key_auth)],
+    response_model=TeamMemberUpdateResponse,
+)
+@management_endpoint_wrapper
+async def team_member_update(
+    data: TeamMemberUpdateRequest,
+    http_request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [BETA]
+
+    Update team member budgets
+    """
+    from litellm.proxy.proxy_server import (
+        _duration_in_seconds,
+        create_audit_log_for_update,
+        litellm_proxy_admin_name,
+        prisma_client,
+    )
+
+    if prisma_client is None:
+        raise HTTPException(status_code=500, detail={"error": "No db connected"})
+
+    if data.team_id is None:
+        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
+
+    if data.user_id is None and data.user_email is None:
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Either user_id or user_email needs to be passed in"},
+        )
+
+    _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
+        where={"team_id": data.team_id}
+    )
+
+    if _existing_team_row is None:
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Team id={} does not exist in db".format(data.team_id)},
+        )
+    existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
+
+    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+    if (
+        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+        and not _is_user_team_admin(
+            user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
+        )
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+                    "/team/member_update", existing_team_row.team_id
+                )
+            },
+        )
+
+    returned_team_info: TeamInfoResponseObject = await team_info(
+        http_request=http_request,
+        team_id=data.team_id,
+        user_api_key_dict=user_api_key_dict,
+    )
+
+    ## get user id
+    received_user_id: Optional[str] = None
+    if data.user_id is not None:
+        received_user_id = data.user_id
+    elif data.user_email is not None:
+        for member in returned_team_info["team_info"].members_with_roles:
+            if member.user_email is not None and member.user_email == data.user_email:
+                received_user_id = member.user_id
+                break
+
+    if received_user_id is None:
+        raise HTTPException(
+            status_code=400,
+            detail={
+                "error": "User id doesn't exist in team table. Data={}".format(data)
+            },
+        )
+    ## find the relevant team membership
+    identified_budget_id: Optional[str] = None
+    for tm in returned_team_info["team_memberships"]:
+        if tm.user_id == received_user_id:
+            identified_budget_id = tm.budget_id
+            break
+
+    ### upsert new budget
+    if identified_budget_id is None:
+        new_budget = await prisma_client.db.litellm_budgettable.create(
+            data={
+                "max_budget": data.max_budget_in_team,
+                "created_by": user_api_key_dict.user_id or "",
+                "updated_by": user_api_key_dict.user_id or "",
+            }
+        )
+
+        await prisma_client.db.litellm_teammembership.create(
+            data={
+                "team_id": data.team_id,
+                "user_id": received_user_id,
+                "budget_id": new_budget.budget_id,
+            },
+        )
+    else:
+        await prisma_client.db.litellm_budgettable.update(
+            where={"budget_id": identified_budget_id},
+            data={"max_budget": data.max_budget_in_team},
+        )
+
+    return TeamMemberUpdateResponse(
+        team_id=data.team_id,
+        user_id=received_user_id,
+        user_email=data.user_email,
+        max_budget_in_team=data.max_budget_in_team,
+    )
+
+
 @router.post(
     "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
 )
@@ -937,12 +1065,18 @@ async def team_info(
             where={"team_id": team_id},
             include={"litellm_budget_table": True},
         )
-        return {
-            "team_id": team_id,
-            "team_info": team_info,
-            "keys": keys,
-            "team_memberships": team_memberships,
-        }
+
+        returned_tm: List[LiteLLM_TeamMembership] = []
+        for tm in team_memberships:
+            returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
+
+        response_object = TeamInfoResponseObject(
+            team_id=team_id,
+            team_info=team_info,
+            keys=keys,
+            team_memberships=returned_tm,
+        )
+        return response_object
     except Exception as e:
         verbose_proxy_logger.error(
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index e36a36d52..388f91ed2 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -359,7 +359,7 @@ async def pass_through_request(
         start_time = datetime.now()
         logging_obj = Logging(
             model="unknown",
-            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
             stream=False,
             call_type="pass_through_endpoint",
             start_time=start_time,
@@ -414,7 +414,7 @@ async def pass_through_request(
         logging_url = str(url) + "?" + requested_query_params_str

         logging_obj.pre_call(
-            input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            input=[{"role": "user", "content": json.dumps(_parsed_body)}],
             api_key="",
             additional_args={
                 "complete_input_dict": _parsed_body,
diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py
index 8451c7e40..fe46ae58c 100644
--- a/litellm/proxy/pass_through_endpoints/success_handler.py
+++ b/litellm/proxy/pass_through_endpoints/success_handler.py
@@ -1,4 +1,6 @@
+import json
 import re
+import threading
 from datetime import datetime
 from typing import Union

@@ -10,6 +12,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
     VertexLLM,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.types.utils import StandardPassThroughResponseObject


 class PassThroughEndpointLogging:
@@ -43,8 +46,24 @@ class PassThroughEndpointLogging:
                 **kwargs,
             )
         else:
+            standard_logging_response_object = StandardPassThroughResponseObject(
+                response=httpx_response.text
+            )
+            threading.Thread(
+                target=logging_obj.success_handler,
+                args=(
+                    standard_logging_response_object,
+                    start_time,
+                    end_time,
+                    cache_hit,
+                ),
+            ).start()
             await logging_obj.async_success_handler(
-                result="",
+                result=(
+                    json.dumps(result)
+                    if isinstance(result, dict)
+                    else standard_logging_response_object
+                ),
                 start_time=start_time,
                 end_time=end_time,
                 cache_hit=False,
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8d7c524a4..f4f3a1e58 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3005,13 +3005,13 @@ def model_list(
     This is just for compatibility with openai projects like aider.
     """

-    global llm_model_list, general_settings
+    global llm_model_list, general_settings, llm_router
     all_models = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
@@ -7503,10 +7503,11 @@ async def model_info_v1(
     all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
+
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
@@ -7523,8 +7524,14 @@ async def model_info_v1(
     if len(all_models_str) > 0:
         model_names = all_models_str
-        _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
-        all_models = copy.deepcopy(_relevant_models)
+        llm_model_list = llm_router.get_model_list()
+        if llm_model_list is not None:
+            _relevant_models = [
+                m for m in llm_model_list if m["model_name"] in model_names
+            ]
+            all_models = copy.deepcopy(_relevant_models)  # type: ignore
+        else:
+            all_models = []

     for model in all_models:
         # provided model_info in config.yaml
@@ -7590,12 +7597,12 @@ async def model_group_info(
         raise HTTPException(
             status_code=500, detail={"error": "LLM Router is not loaded in"}
         )
-    all_models: List[dict] = []
     ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-    if llm_model_list is None:
+    if llm_router is None:
         proxy_model_list = []
     else:
-        proxy_model_list = [m["model_name"] for m in llm_model_list]
+        proxy_model_list = llm_router.get_model_names()
+
     key_models = get_key_models(
         user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
     )
diff --git a/litellm/router.py b/litellm/router.py
index 233331e80..bcd0b6221 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -86,6 +86,7 @@ from litellm.types.router import (
     Deployment,
     DeploymentTypedDict,
     LiteLLM_Params,
+    LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
     RetryPolicy,
@@ -4297,7 +4298,9 @@ class Router:
                 return model
         return None

-    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+    def _set_model_group_info(
+        self, model_group: str, user_facing_model_group_name: str
+    ) -> Optional[ModelGroupInfo]:
         """
         For a given model group name, return the combined model info

@@ -4379,7 +4382,7 @@ class Router:

             if model_group_info is None:
                 model_group_info = ModelGroupInfo(
-                    model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    model_group=user_facing_model_group_name, providers=[llm_provider], **model_info  # type: ignore
                 )
             else:
                 # if max_input_tokens > curr
@@ -4464,6 +4467,26 @@ class Router:

         return model_group_info

+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+        ## Check if model group alias
+        if model_group in self.model_group_alias:
+            return self._set_model_group_info(
+                model_group=self.model_group_alias[model_group],
+                user_facing_model_group_name=model_group,
+            )
+
+        ## Check if actual model
+        return self._set_model_group_info(
+            model_group=model_group, user_facing_model_group_name=model_group
+        )
+
     async def get_model_group_usage(
         self, model_group: str
     ) -> Tuple[Optional[int], Optional[int]]:
@@ -4534,19 +4557,35 @@ class Router:
         return ids

     def get_model_names(self) -> List[str]:
-        return self.model_names
+        """
+        Returns all possible model names for router.
+
+        Includes model_group_alias models too.
+        """
+        return self.model_names + list(self.model_group_alias.keys())

     def get_model_list(
         self, model_name: Optional[str] = None
     ) -> Optional[List[DeploymentTypedDict]]:
         if hasattr(self, "model_list"):
-            if model_name is None:
-                return self.model_list
-
             returned_models: List[DeploymentTypedDict] = []
+
+            for model_alias, model_value in self.model_group_alias.items():
+                model_alias_item = DeploymentTypedDict(
+                    model_name=model_alias,
+                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
+                )
+                returned_models.append(model_alias_item)
+
+            if model_name is None:
+                returned_models += self.model_list
+
+                return returned_models
+
             for model in self.model_list:
                 if model["model_name"] == model_name:
                     returned_models.append(model)
+            return returned_models
         return None
diff --git a/litellm/tests/test_openai_batches_and_files.py b/litellm/tests/test_openai_batches_and_files.py
index 5ac5e4b10..4c55ab8fa 100644
--- a/litellm/tests/test_openai_batches_and_files.py
+++ b/litellm/tests/test_openai_batches_and_files.py
@@ -22,7 +22,7 @@ import litellm
 from litellm import create_batch, create_file


-@pytest.mark.parametrize("provider", ["openai", "azure"])
+@pytest.mark.parametrize("provider", ["openai"])  # , "azure"
 def test_create_batch(provider):
     """
     1. Create File for Batch completion
@@ -96,7 +96,7 @@ def test_create_batch(provider):
     pass


-@pytest.mark.parametrize("provider", ["openai", "azure"])
+@pytest.mark.parametrize("provider", ["openai"])  # "azure"
 @pytest.mark.asyncio()
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_async_create_batch(provider):
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 1a8cb831e..fd89130fe 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2396,6 +2396,7 @@ async def test_router_weighted_pick(sync_mode):
     assert model_id_1_count > model_id_2_count


+@pytest.mark.skip(reason="Hit azure batch quota limits")
 @pytest.mark.parametrize("provider", ["azure"])
 @pytest.mark.asyncio
 async def test_router_batch_endpoints(provider):
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index ec5b43735..97b34b490 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -12,6 +12,7 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch

 import litellm
 from litellm import Router
@@ -1266,3 +1267,73 @@ async def test_using_default_working_fallback(sync_mode):
     )
     print("got response=", response)
     assert response is not None
+
+
+# asyncio.run(test_acompletion_gemini_stream())
+def mock_post_streaming(url, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 529
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.return_value = {"detail": "Overloaded!"}
+
+    return mock_response
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_anthropic_streaming_fallbacks(sync_mode):
+    litellm.set_verbose = True
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+    if sync_mode:
+        client = HTTPHandler(concurrent_limit=1)
+    else:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "anthropic/claude-3-5-sonnet-20240620",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                },
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "Hey, how's it going?",
+                },
+            },
+        ],
+        fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
+        num_retries=0,
+    )
+
+    with patch.object(client, "post", side_effect=mock_post_streaming) as mock_client:
+        chunks = []
+        if sync_mode:
+            response = router.completion(
+                model="anthropic/claude-3-5-sonnet-20240620",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                stream=True,
+                client=client,
+            )
+            for chunk in response:
+                print(chunk)
+                chunks.append(chunk)
+        else:
+            response = await router.acompletion(
+                model="anthropic/claude-3-5-sonnet-20240620",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                stream=True,
+                client=client,
+            )
+            async for chunk in response:
+                print(chunk)
+                chunks.append(chunk)
+        print(f"RETURNED response: {response}")
+
+        mock_client.assert_called_once()
+        print(chunks)
+        assert len(chunks) > 0
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
index f0b77af4f..effffedf0 100644
--- a/litellm/tests/test_sagemaker.py
+++ b/litellm/tests/test_sagemaker.py
@@ -127,6 +127,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
         "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
     ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_completion_sagemaker_stream(sync_mode, model):
     try:
         from litellm.tests.test_streaming import streaming_format_tests
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 43313b7f7..7a56d0703 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -3144,7 +3144,6 @@ async def test_azure_astreaming_and_function_calling():


 def test_completion_claude_3_function_call_with_streaming():
-    litellm.set_verbose = True
     tools = [
         {
             "type": "function",
@@ -3827,6 +3826,65 @@ def test_unit_test_custom_stream_wrapper_function_call():
     assert len(new_model.choices[0].delta.tool_calls) > 0


+def test_unit_test_perplexity_citations_chunk():
+    """
+    Test that the 'citations' field on a Perplexity chunk is preserved on the streamed response chunks
+    """
+    from litellm.types.llms.openai import ChatCompletionDeltaChunk
+
+    litellm.set_verbose = False
+    delta: ChatCompletionDeltaChunk = {
+        "content": "B",
+        "role": "assistant",
+    }
+    chunk = {
+        "id": "xxx",
+        "model": "llama-3.1-sonar-small-128k-online",
+        "created": 1725494279,
+        "usage": {"prompt_tokens": 15, "completion_tokens": 1, "total_tokens": 16},
+        "citations": [
+            "https://x.com/bizzabo?lang=ur",
+            "https://apps.apple.com/my/app/bizzabo/id408705047",
+            "https://www.bizzabo.com/blog/maximize-event-data-strategies-for-success",
+        ],
+        "object": "chat.completion",
+        "choices": [
+            {
+                "index": 0,
+                "finish_reason": None,
+                "message": {"role": "assistant", "content": "B"},
+                "delta": delta,
+            }
+        ],
+    }
+    chunk = litellm.ModelResponse(**chunk, stream=True)
+
+    completion_stream = ModelResponseIterator(model_response=chunk)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    finish_reason: Optional[str] = None
+    for response_chunk in response:
+        if response_chunk.choices[0].delta.content is not None:
+            print(
+                f"response_chunk.choices[0].delta.content: {response_chunk.choices[0].delta.content}"
+            )
+            assert "citations" in response_chunk
+
+
 @pytest.mark.parametrize(
     "model",
     [
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index e9fe7d963..96c25e5b5 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1297,3 +1297,7 @@ class CustomStreamingDecoder:
         self, iterator: Iterator[bytes]
     ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
         raise NotImplementedError
+
+
+class StandardPassThroughResponseObject(TypedDict):
+    response: str
diff --git a/litellm/utils.py b/litellm/utils.py
index c362a7b5a..d8aa51bd5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9845,6 +9845,9 @@ class CustomStreamWrapper:
                     model_response.system_fingerprint = (
                         original_chunk.system_fingerprint
                     )
+                    model_response.citations = getattr(
+                        original_chunk, "citations", None
+                    )
                     print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
                     if self.sent_first_chunk is False:
                         model_response.choices[0].delta["role"] = "assistant"
@@ -10460,6 +10463,8 @@ class TextCompletionStreamWrapper:
 def mock_completion_streaming_obj(
     model_response, mock_response, model, n: Optional[int] = None
 ):
+    if isinstance(mock_response, litellm.MockException):
+        raise mock_response
     for i in range(0, len(mock_response), 3):
         completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
         if n is None:
@@ -10481,6 +10486,8 @@ def mock_completion_streaming_obj(
 async def async_mock_completion_streaming_obj(
     model_response, mock_response, model, n: Optional[int] = None
 ):
+    if isinstance(mock_response, litellm.MockException):
+        raise mock_response
     for i in range(0, len(mock_response), 3):
         completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
         if n is None:
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index b58725d5f..487e187a3 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -3408,6 +3408,15 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.titan-text-premier-v1:0": {
+        "max_tokens": 32000,
+        "max_input_tokens": 42000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015,
+        "litellm_provider": "bedrock",
+        "mode": "chat"
+    },
     "amazon.titan-embed-text-v1": {
         "max_tokens": 8192,
         "max_input_tokens": 8192,