LiteLLM Minor Fixes and Improvements (08/06/2024) (#5567)

* fix(utils.py): return citations for perplexity streaming

Fixes https://github.com/BerriAI/litellm/issues/5535
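
A minimal usage sketch of the restored behaviour (model id taken from the accompanying test, API key is a placeholder): with this fix, provider-returned `citations` ride along on each streamed chunk instead of being dropped.

```python
import os
import litellm

os.environ["PERPLEXITYAI_API_KEY"] = "pplx-..."  # placeholder key

response = litellm.completion(
    model="perplexity/llama-3.1-sonar-small-128k-online",
    messages=[{"role": "user", "content": "Who founded Perplexity AI?"}],
    stream=True,
)

for chunk in response:
    # `citations` is now copied onto the streamed chunk when the provider returns it
    citations = getattr(chunk, "citations", None)
    if citations:
        print("citations:", citations)
```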

* fix(anthropic/chat.py): support fallbacks for anthropic streaming (#5542)

* fix(anthropic/chat.py): support fallbacks for anthropic streaming

Fixes https://github.com/BerriAI/litellm/issues/5512
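
A sketch mirroring the new router fallback test (model names and `mock_response` are illustrative): with this fix, a streamed Anthropic call that fails at stream setup (e.g. a 529) is retried on the fallback deployment instead of surfacing the error to the caller.

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "anthropic/claude-3-5-sonnet-20240620",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hey, how's it going?",
            },
        },
    ],
    fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
    num_retries=0,
)

response = router.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```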

* fix(anthropic/chat.py): use module level http client if none given (prevents early client closure)

* fix: fix linting errors

* fix(http_handler.py): fix raise_for_status error handling
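
For context, a standalone sketch of the pattern the handler now follows (not the exact handler code): call `raise_for_status()` and translate `httpx.HTTPStatusError` so downstream provider-error mapping still sees a status code and message body.

```python
import httpx

def post_with_status(url: str, **kwargs) -> httpx.Response:
    client = httpx.Client()
    try:
        response = client.post(url, **kwargs)
        response.raise_for_status()  # surface 4xx/5xx as exceptions
        return response
    except httpx.HTTPStatusError as e:
        # keep what callers need in order to map the provider error
        setattr(e, "status_code", e.response.status_code)
        setattr(e, "message", e.response.text)
        raise e
```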

* test: retry flaky test

* fix otel type

* fix(bedrock/embed): fix error raising
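
The embed fix is small but easy to miss: the error body must be read from the exception's own response (`err.response.text`), not from an outer `response` variable. A hedged sketch of the corrected shape (RuntimeError stands in for BedrockError to keep the sketch self-contained):

```python
import httpx

def _raise_bedrock_error(response: httpx.Response) -> None:
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        error_code = err.response.status_code
        # read the body from the exception's response so the message matches the failure
        raise RuntimeError(f"BedrockError {error_code}: {err.response.text}")
```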

* test(test_openai_batches_and_files.py): skip azure batches test (for now) - quota exceeded

* fix(test_router.py): skip azure batch route test (for now) - hit batch quota limits

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>

* All `model_group_alias` should show up in `/models`, `/model/info` , `/model_group/info` (#5539)

* fix(router.py): support returning model_alias model names in `/v1/models`

* fix(proxy_server.py): support returning model aliases on `/model/info`

* feat(router.py): support returning model group alias for `/model_group/info`
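
A sketch of the user-facing effect (model names assumed): aliases registered through `model_group_alias` now appear next to real model groups in `Router.get_model_names()`, and therefore in the proxy's `/v1/models`, `/model/info` and `/model_group/info` responses.

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    model_group_alias={"my-cheap-model": "gpt-3.5-turbo"},  # alias -> underlying group
)

print(router.get_model_names())
# expected after this change: ['gpt-3.5-turbo', 'my-cheap-model']
```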

* fix(proxy_server.py): fix linting errors

* fix(proxy_server.py): fix linting errors

* build(model_prices_and_context_window.json): add amazon titan text premier pricing information

Closes https://github.com/BerriAI/litellm/issues/5560
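
With the pricing entry in place, cost tracking should work for Titan Text Premier. A rough check, assuming `litellm.cost_per_token`'s `(model, prompt_tokens, completion_tokens)` signature and made-up token counts:

```python
import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="amazon.titan-text-premier-v1:0",
    prompt_tokens=1000,
    completion_tokens=200,
)
# with the listed prices: 1000 * 5e-07 = 0.0005 and 200 * 1.5e-06 = 0.0003
print(prompt_cost, completion_cost)
```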

* feat(litellm_logging.py): log standard logging response object for pass through endpoints. Allows bedrock /invoke agent calls to be correctly logged to langfuse + s3
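
Under the hood this wraps the raw pass-through response text in a small typed object, so the logging callbacks receive a standard payload instead of an empty string. A simplified sketch using the new type:

```python
from litellm.types.utils import StandardPassThroughResponseObject

# e.g. the body returned by a Bedrock /invoke agent call, as text
raw_response_text = '{"completion": "...agent output..."}'

standard_logging_response_object = StandardPassThroughResponseObject(
    response=raw_response_text
)
print(standard_logging_response_object["response"])
```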

* fix(success_handler.py): fix linting error

* fix(success_handler.py): fix linting errors

* fix(team_endpoints.py): allow admins to update team member budgets
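
A sketch of exercising the new [BETA] route against a running proxy (base URL and keys are placeholders; field names follow `TeamMemberUpdateRequest`):

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/team/member_update",  # assumed proxy address
    headers={"Authorization": "Bearer sk-1234"},  # proxy-admin or team-admin key
    json={
        "team_id": "my-team-id",
        "user_id": "my-user-id",       # or pass "user_email" instead
        "max_budget_in_team": 10.0,    # per-member budget within the team
    },
)
print(resp.json())
# expected shape: team_id, user_id, user_email, max_budget_in_team
```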

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Author: Krish Dholakia · 2024-09-06 17:16:24 -07:00 (committed by GitHub)
Commit: 72e961af3c · Parent: e4dcd6f745
25 changed files with 509 additions and 99 deletions

View file

@ -22,7 +22,8 @@
"ms-python.python",
"ms-python.vscode-pylance",
"GitHub.copilot",
"GitHub.copilot-chat"
"GitHub.copilot-chat",
"ms-python.autopep8"
]
}
},

View file

@ -1,12 +1,12 @@
repos:
- repo: local
hooks:
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
- id: isort
name: isort
entry: isort

View file

@ -208,6 +208,14 @@ class LangFuseLogger:
):
input = prompt
output = response_obj["text"]
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
and response_obj is not None
and isinstance(response_obj, dict)
):
input = prompt
output = response_obj.get("response", "")
print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
trace_id = None
generation_id = None

View file

@ -101,12 +101,6 @@ class S3Logger:
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
# Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references which leads to infinite recursion
@ -171,5 +165,5 @@ class S3Logger:
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
return response
except Exception as e:
verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass

View file

@ -41,6 +41,7 @@ from litellm.types.utils import (
StandardLoggingMetadata,
StandardLoggingModelInformation,
StandardLoggingPayload,
StandardPassThroughResponseObject,
TextCompletionResponse,
TranscriptionResponse,
)
@ -534,7 +535,9 @@ class Logging:
"""
## RESPONSE COST ##
custom_pricing = use_custom_pricing_for_model(
litellm_params=self.litellm_params
litellm_params=(
self.litellm_params if hasattr(self, "litellm_params") else None
)
)
if cache_hit is None:
@ -611,6 +614,17 @@ class Logging:
] = result._hidden_params
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
elif isinstance(result, dict): # pass-through endpoints
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
@ -2271,6 +2285,8 @@ def get_standard_logging_object_payload(
elif isinstance(init_response_obj, BaseModel):
response_obj = init_response_obj.model_dump()
hidden_params = getattr(init_response_obj, "_hidden_params", None)
elif isinstance(init_response_obj, dict):
response_obj = init_response_obj
else:
response_obj = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging

View file

@ -674,7 +674,7 @@ async def make_call(
timeout: Optional[Union[float, httpx.Timeout]],
):
if client is None:
client = _get_async_httpx_client() # Create a new client if none provided
client = litellm.module_level_aclient
try:
response = await client.post(
@ -690,11 +690,6 @@ async def make_call(
raise e
raise AnthropicError(status_code=500, message=str(e))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=await response.aread()
)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
)
@ -721,7 +716,7 @@ def make_sync_call(
timeout: Optional[Union[float, httpx.Timeout]],
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
client = litellm.module_level_client # re-use a module level client
try:
response = client.post(
@ -869,6 +864,7 @@ class AnthropicChatCompletion(BaseLLM):
model_response: ModelResponse,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
client: Optional[AsyncHTTPHandler],
encoding,
api_key,
logging_obj,
@ -882,19 +878,18 @@ class AnthropicChatCompletion(BaseLLM):
):
data["stream"] = True
completion_stream = await make_call(
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
streamwrapper = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=None,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
),
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
@ -1080,6 +1075,11 @@ class AnthropicChatCompletion(BaseLLM):
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
else:
return self.acompletion_function(
@ -1105,33 +1105,32 @@ class AnthropicChatCompletion(BaseLLM):
)
else:
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(timeout=timeout) # type: ignore
else:
client = client
if (
stream is True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
data["stream"] = stream
completion_stream = make_sync_call(
client=client,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
)
return CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
),
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
else:
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(timeout=timeout) # type: ignore
else:
client = client
response = client.post(
api_base, headers=headers, data=json.dumps(data), timeout=timeout
)

View file

@ -110,7 +110,7 @@ class BedrockEmbedding(BaseAWSLLM):
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")

View file

@ -37,16 +37,10 @@ class AsyncHTTPHandler:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv(
"SSL_CERTIFICATE",
litellm.ssl_certificate
)
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
if timeout is None:
timeout = _DEFAULT_TIMEOUT
@ -277,16 +271,10 @@ class HTTPHandler:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv(
"SSL_VERIFY",
litellm.ssl_verify
)
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv(
"SSL_CERTIFICATE",
litellm.ssl_certificate
)
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
if client is None:
# Create a client with a connection pool
@ -334,6 +322,7 @@ class HTTPHandler:
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
response.raise_for_status()
return response
except httpx.TimeoutException:
raise litellm.Timeout(
@ -341,6 +330,13 @@ class HTTPHandler:
model="default-model-name",
llm_provider="litellm-httpx-handler",
)
except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code)
if stream is True:
setattr(e, "message", e.response.read())
else:
setattr(e, "message", e.response.text)
raise e
except Exception as e:
raise e
@ -375,7 +371,6 @@ class HTTPHandler:
except Exception as e:
raise e
def __del__(self) -> None:
try:
self.close()
@ -437,4 +432,4 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
_new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
return _new_client
return _new_client

View file

@ -534,6 +534,15 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response.startswith(
"Exception: mock_streaming_error"
):
mock_response = litellm.MockException(
message="This is a mock error raised mid-stream",
llm_provider="anthropic",
model=model,
status_code=529,
)
time_delay = kwargs.get("mock_delay", None)
if time_delay is not None:
time.sleep(time_delay)
@ -561,6 +570,8 @@ def mock_completion(
custom_llm_provider="openai",
logging_obj=logging,
)
if isinstance(mock_response, litellm.MockException):
raise mock_response
if n is None:
model_response.choices[0].message.content = mock_response # type: ignore
else:

View file

@ -3408,6 +3408,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.titan-text-premier-v1:0": {
"max_tokens": 32000,
"max_input_tokens": 42000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.titan-embed-text-v1": {
"max_tokens": 8192,
"max_input_tokens": 8192,

View file

@ -1,7 +1,16 @@
model_list:
- model_name: "*"
- model_name: "anthropic/claude-3-5-sonnet-20240620"
litellm_params:
model: anthropic/claude-3-5-sonnet-20240620
# api_base: http://0.0.0.0:9000
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/*
litellm_settings:
default_max_internal_user_budget: 2
success_callback: ["s3"]
s3_callback_params:
s3_bucket_name: litellm-logs # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -872,6 +872,17 @@ class TeamMemberDeleteRequest(LiteLLMBase):
return values
class TeamMemberUpdateRequest(TeamMemberDeleteRequest):
max_budget_in_team: float
class TeamMemberUpdateResponse(LiteLLMBase):
team_id: str
user_id: str
user_email: Optional[str] = None
max_budget_in_team: float
class UpdateTeamRequest(LiteLLMBase):
"""
UpdateTeamRequest, used by /team/update when you need to update a team
@ -1854,3 +1865,10 @@ class LiteLLM_TeamMembership(LiteLLMBase):
class TeamAddMemberResponse(LiteLLM_TeamTable):
updated_users: List[LiteLLM_UserTable]
updated_team_memberships: List[LiteLLM_TeamMembership]
class TeamInfoResponseObject(TypedDict):
team_id: str
team_info: TeamBase
keys: List
team_memberships: List[LiteLLM_TeamMembership]

View file

@ -28,8 +28,11 @@ from litellm.proxy._types import (
ProxyErrorTypes,
ProxyException,
TeamAddMemberResponse,
TeamInfoResponseObject,
TeamMemberAddRequest,
TeamMemberDeleteRequest,
TeamMemberUpdateRequest,
TeamMemberUpdateResponse,
UpdateTeamRequest,
UserAPIKeyAuth,
)
@ -750,6 +753,131 @@ async def team_member_delete(
return existing_team_row
@router.post(
"/team/member_update",
tags=["team management"],
dependencies=[Depends(user_api_key_auth)],
response_model=TeamMemberUpdateResponse,
)
@management_endpoint_wrapper
async def team_member_update(
data: TeamMemberUpdateRequest,
http_request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[BETA]
Update team member budgets
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
litellm_proxy_admin_name,
prisma_client,
)
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
if data.team_id is None:
raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
if data.user_id is None and data.user_email is None:
raise HTTPException(
status_code=400,
detail={"error": "Either user_id or user_email needs to be passed in"},
)
_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": data.team_id}
)
if _existing_team_row is None:
raise HTTPException(
status_code=400,
detail={"error": "Team id={} does not exist in db".format(data.team_id)},
)
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
if (
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
and not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
)
):
raise HTTPException(
status_code=403,
detail={
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
"/team/member_delete", existing_team_row.team_id
)
},
)
returned_team_info: TeamInfoResponseObject = await team_info(
http_request=http_request,
team_id=data.team_id,
user_api_key_dict=user_api_key_dict,
)
## get user id
received_user_id: Optional[str] = None
if data.user_id is not None:
received_user_id = data.user_id
elif data.user_email is not None:
for member in returned_team_info["team_info"].members_with_roles:
if member.user_email is not None and member.user_email == data.user_email:
received_user_id = member.user_id
break
if received_user_id is None:
raise HTTPException(
status_code=400,
detail={
"error": "User id doesn't exist in team table. Data={}".format(data)
},
)
## find the relevant team membership
identified_budget_id: Optional[str] = None
for tm in returned_team_info["team_memberships"]:
if tm.user_id == received_user_id:
identified_budget_id = tm.budget_id
break
### upsert new budget
if identified_budget_id is None:
new_budget = await prisma_client.db.litellm_budgettable.create(
data={
"max_budget": data.max_budget_in_team,
"created_by": user_api_key_dict.user_id or "",
"updated_by": user_api_key_dict.user_id or "",
}
)
await prisma_client.db.litellm_teammembership.create(
data={
"team_id": data.team_id,
"user_id": received_user_id,
"budget_id": new_budget.budget_id,
},
)
else:
await prisma_client.db.litellm_budgettable.update(
where={"budget_id": identified_budget_id},
data={"max_budget": data.max_budget_in_team},
)
return TeamMemberUpdateResponse(
team_id=data.team_id,
user_id=received_user_id,
user_email=data.user_email,
max_budget_in_team=data.max_budget_in_team,
)
@router.post(
"/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@ -937,12 +1065,18 @@ async def team_info(
where={"team_id": team_id},
include={"litellm_budget_table": True},
)
return {
"team_id": team_id,
"team_info": team_info,
"keys": keys,
"team_memberships": team_memberships,
}
returned_tm: List[LiteLLM_TeamMembership] = []
for tm in team_memberships:
returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
response_object = TeamInfoResponseObject(
team_id=team_id,
team_info=team_info,
keys=keys,
team_memberships=returned_tm,
)
return response_object
except Exception as e:
verbose_proxy_logger.error(

View file

@ -359,7 +359,7 @@ async def pass_through_request(
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
@ -414,7 +414,7 @@ async def pass_through_request(
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
input=[{"role": "user", "content": json.dumps(_parsed_body)}],
api_key="",
additional_args={
"complete_input_dict": _parsed_body,

View file

@ -1,4 +1,6 @@
import json
import re
import threading
from datetime import datetime
from typing import Union
@ -10,6 +12,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
VertexLLM,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject
class PassThroughEndpointLogging:
@ -43,8 +46,24 @@ class PassThroughEndpointLogging:
**kwargs,
)
else:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
threading.Thread(
target=logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
cache_hit,
),
).start()
await logging_obj.async_success_handler(
result="",
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,

View file

@ -3005,13 +3005,13 @@ def model_list(
This is just for compatibility with openai projects like aider.
"""
global llm_model_list, general_settings
global llm_model_list, general_settings, llm_router
all_models = []
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_model_list is None:
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = [m["model_name"] for m in llm_model_list]
proxy_model_list = llm_router.get_model_names()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
)
@ -7503,10 +7503,11 @@ async def model_info_v1(
all_models: List[dict] = []
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_model_list is None:
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = [m["model_name"] for m in llm_model_list]
proxy_model_list = llm_router.get_model_names()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
)
@ -7523,8 +7524,14 @@ async def model_info_v1(
if len(all_models_str) > 0:
model_names = all_models_str
_relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
all_models = copy.deepcopy(_relevant_models)
llm_model_list = llm_router.get_model_list()
if llm_model_list is not None:
_relevant_models = [
m for m in llm_model_list if m["model_name"] in model_names
]
all_models = copy.deepcopy(_relevant_models) # type: ignore
else:
all_models = []
for model in all_models:
# provided model_info in config.yaml
@ -7590,12 +7597,12 @@ async def model_group_info(
raise HTTPException(
status_code=500, detail={"error": "LLM Router is not loaded in"}
)
all_models: List[dict] = []
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_model_list is None:
if llm_router is None:
proxy_model_list = []
else:
proxy_model_list = [m["model_name"] for m in llm_model_list]
proxy_model_list = llm_router.get_model_names()
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
)

View file

@ -86,6 +86,7 @@ from litellm.types.router import (
Deployment,
DeploymentTypedDict,
LiteLLM_Params,
LiteLLMParamsTypedDict,
ModelGroupInfo,
ModelInfo,
RetryPolicy,
@ -4297,7 +4298,9 @@ class Router:
return model
return None
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
def _set_model_group_info(
self, model_group: str, user_facing_model_group_name: str
) -> Optional[ModelGroupInfo]:
"""
For a given model group name, return the combined model info
@ -4379,7 +4382,7 @@ class Router:
if model_group_info is None:
model_group_info = ModelGroupInfo(
model_group=model_group, providers=[llm_provider], **model_info # type: ignore
model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore
)
else:
# if max_input_tokens > curr
@ -4464,6 +4467,26 @@ class Router:
return model_group_info
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
"""
For a given model group name, return the combined model info
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
## Check if model group alias
if model_group in self.model_group_alias:
return self._set_model_group_info(
model_group=self.model_group_alias[model_group],
user_facing_model_group_name=model_group,
)
## Check if actual model
return self._set_model_group_info(
model_group=model_group, user_facing_model_group_name=model_group
)
async def get_model_group_usage(
self, model_group: str
) -> Tuple[Optional[int], Optional[int]]:
@ -4534,19 +4557,35 @@ class Router:
return ids
def get_model_names(self) -> List[str]:
return self.model_names
"""
Returns all possible model names for router.
Includes model_group_alias models too.
"""
return self.model_names + list(self.model_group_alias.keys())
def get_model_list(
self, model_name: Optional[str] = None
) -> Optional[List[DeploymentTypedDict]]:
if hasattr(self, "model_list"):
if model_name is None:
return self.model_list
returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items():
model_alias_item = DeploymentTypedDict(
model_name=model_alias,
litellm_params=LiteLLMParamsTypedDict(model=model_value),
)
returned_models.append(model_alias_item)
if model_name is None:
returned_models += self.model_list
return returned_models
for model in self.model_list:
if model["model_name"] == model_name:
returned_models.append(model)
return returned_models
return None

View file

@ -22,7 +22,7 @@ import litellm
from litellm import create_batch, create_file
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
def test_create_batch(provider):
"""
1. Create File for Batch completion
@ -96,7 +96,7 @@ def test_create_batch(provider):
pass
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("provider", ["openai"]) # "azure"
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_async_create_batch(provider):

View file

@ -2396,6 +2396,7 @@ async def test_router_weighted_pick(sync_mode):
assert model_id_1_count > model_id_2_count
@pytest.mark.skip(reason="Hit azure batch quota limits")
@pytest.mark.parametrize("provider", ["azure"])
@pytest.mark.asyncio
async def test_router_batch_endpoints(provider):

View file

@ -12,6 +12,7 @@ import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import Router
@ -1266,3 +1267,73 @@ async def test_using_default_working_fallback(sync_mode):
)
print("got response=", response)
assert response is not None
# asyncio.run(test_acompletion_gemini_stream())
def mock_post_streaming(url, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 529
mock_response.headers = {"Content-Type": "application/json"}
mock_response.return_value = {"detail": "Overloaded!"}
return mock_response
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_anthropic_streaming_fallbacks(sync_mode):
litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
if sync_mode:
client = HTTPHandler(concurrent_limit=1)
else:
client = AsyncHTTPHandler(concurrent_limit=1)
router = Router(
model_list=[
{
"model_name": "anthropic/claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "Hey, how's it going?",
},
},
],
fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
num_retries=0,
)
with patch.object(client, "post", side_effect=mock_post_streaming) as mock_client:
chunks = []
if sync_mode:
response = router.completion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=True,
client=client,
)
for chunk in response:
print(chunk)
chunks.append(chunk)
else:
response = await router.acompletion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=True,
client=client,
)
async for chunk in response:
print(chunk)
chunks.append(chunk)
print(f"RETURNED response: {response}")
mock_client.assert_called_once()
print(chunks)
assert len(chunks) > 0

View file

@ -127,6 +127,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
"sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
],
)
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_sagemaker_stream(sync_mode, model):
try:
from litellm.tests.test_streaming import streaming_format_tests

View file

@ -3144,7 +3144,6 @@ async def test_azure_astreaming_and_function_calling():
def test_completion_claude_3_function_call_with_streaming():
litellm.set_verbose = True
tools = [
{
"type": "function",
@ -3827,6 +3826,65 @@ def test_unit_test_custom_stream_wrapper_function_call():
assert len(new_model.choices[0].delta.tool_calls) > 0
def test_unit_test_perplexity_citations_chunk():
"""
Test that Perplexity `citations` are preserved on streamed chunks
"""
from litellm.types.llms.openai import ChatCompletionDeltaChunk
litellm.set_verbose = False
delta: ChatCompletionDeltaChunk = {
"content": "B",
"role": "assistant",
}
chunk = {
"id": "xxx",
"model": "llama-3.1-sonar-small-128k-online",
"created": 1725494279,
"usage": {"prompt_tokens": 15, "completion_tokens": 1, "total_tokens": 16},
"citations": [
"https://x.com/bizzabo?lang=ur",
"https://apps.apple.com/my/app/bizzabo/id408705047",
"https://www.bizzabo.com/blog/maximize-event-data-strategies-for-success",
],
"object": "chat.completion",
"choices": [
{
"index": 0,
"finish_reason": None,
"message": {"role": "assistant", "content": "B"},
"delta": delta,
}
],
}
chunk = litellm.ModelResponse(**chunk, stream=True)
completion_stream = ModelResponseIterator(model_response=chunk)
response = litellm.CustomStreamWrapper(
completion_stream=completion_stream,
model="gpt-3.5-turbo",
custom_llm_provider="cached_response",
logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey"}],
stream=True,
call_type="completion",
start_time=time.time(),
litellm_call_id="12345",
function_id="1245",
),
)
finish_reason: Optional[str] = None
for response_chunk in response:
if response_chunk.choices[0].delta.content is not None:
print(
f"response_chunk.choices[0].delta.content: {response_chunk.choices[0].delta.content}"
)
assert "citations" in response_chunk
@pytest.mark.parametrize(
"model",
[

View file

@ -1297,3 +1297,7 @@ class CustomStreamingDecoder:
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
raise NotImplementedError
class StandardPassThroughResponseObject(TypedDict):
response: str

View file

@ -9845,6 +9845,9 @@ class CustomStreamWrapper:
model_response.system_fingerprint = (
original_chunk.system_fingerprint
)
model_response.citations = getattr(
original_chunk, "citations", None
)
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
if self.sent_first_chunk is False:
model_response.choices[0].delta["role"] = "assistant"
@ -10460,6 +10463,8 @@ class TextCompletionStreamWrapper:
def mock_completion_streaming_obj(
model_response, mock_response, model, n: Optional[int] = None
):
if isinstance(mock_response, litellm.MockException):
raise mock_response
for i in range(0, len(mock_response), 3):
completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
if n is None:
@ -10481,6 +10486,8 @@ def mock_completion_streaming_obj(
async def async_mock_completion_streaming_obj(
model_response, mock_response, model, n: Optional[int] = None
):
if isinstance(mock_response, litellm.MockException):
raise mock_response
for i in range(0, len(mock_response), 3):
completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
if n is None:

View file

@ -3408,6 +3408,15 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.titan-text-premier-v1:0": {
"max_tokens": 32000,
"max_input_tokens": 42000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"amazon.titan-embed-text-v1": {
"max_tokens": 8192,
"max_input_tokens": 8192,