mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
LiteLLM Minor Fixes and Improvements (08/06/2024) (#5567)
* fix(utils.py): return citations for perplexity streaming
  Fixes https://github.com/BerriAI/litellm/issues/5535
* fix(anthropic/chat.py): support fallbacks for anthropic streaming (#5542)
  Fixes https://github.com/BerriAI/litellm/issues/5512
* fix(anthropic/chat.py): use module level http client if none given (prevents early client closure)
* fix: fix linting errors
* fix(http_handler.py): fix raise_for_status error handling
* test: retry flaky test
* fix otel type
* fix(bedrock/embed): fix error raising
* test(test_openai_batches_and_files.py): skip azure batches test (for now) - quota exceeded
* fix(test_router.py): skip azure batch route test (for now) - hit batch quota limits
* All `model_group_alias` entries should show up in `/models`, `/model/info`, `/model_group/info` (#5539)
  * fix(router.py): support returning model_alias model names in `/v1/models`
  * fix(proxy_server.py): support returning model aliases on `/model/info`
  * feat(router.py): support returning model group alias for `/model_group/info`
  * fix(proxy_server.py): fix linting errors
* build(model_prices_and_context_window.json): add amazon titan text premier pricing information
  Closes https://github.com/BerriAI/litellm/issues/5560
* feat(litellm_logging.py): log standard logging response object for pass-through endpoints, so bedrock /invoke agent calls are correctly logged to langfuse + s3
* fix(success_handler.py): fix linting errors
* fix(team_endpoints.py): allow admin to update team member budgets

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
parent e4dcd6f745
commit 72e961af3c
25 changed files with 509 additions and 99 deletions
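The headline fix here is fallback support for Anthropic streaming calls. A minimal sketch of how that path can be exercised through the Router, mirroring the `test_anthropic_streaming_fallbacks` test added further down in this diff (the fallback deployment uses `mock_response`, so the snippet is illustrative and needs no real provider keys):

```python
from litellm import Router

# Primary Anthropic deployment plus a fallback deployment.
# The fallback uses mock_response so this sketch runs without provider keys.
router = Router(
    model_list=[
        {
            "model_name": "anthropic/claude-3-5-sonnet-20240620",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hey, how's it going?",
            },
        },
    ],
    fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
    num_retries=0,
)

# With this commit, an error raised while streaming from the Anthropic
# deployment triggers the configured fallback instead of surfacing directly.
response = router.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```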
@@ -22,7 +22,8 @@
        "ms-python.python",
        "ms-python.vscode-pylance",
        "GitHub.copilot",
        "GitHub.copilot-chat"
        "GitHub.copilot-chat",
        "ms-python.autopep8"
      ]
    }
  },

@@ -1,12 +1,12 @@
repos:
- repo: local
  hooks:
    # - id: mypy
    #   name: mypy
    #   entry: python3 -m mypy --ignore-missing-imports
    #   language: system
    #   types: [python]
    #   files: ^litellm/
    - id: mypy
      name: mypy
      entry: python3 -m mypy --ignore-missing-imports
      language: system
      types: [python]
      files: ^litellm/
    - id: isort
      name: isort
      entry: isort

@@ -208,6 +208,14 @@ class LangFuseLogger:
        ):
            input = prompt
            output = response_obj["text"]
        elif (
            kwargs.get("call_type") is not None
            and kwargs.get("call_type") == "pass_through_endpoint"
            and response_obj is not None
            and isinstance(response_obj, dict)
        ):
            input = prompt
            output = response_obj.get("response", "")
        print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
        trace_id = None
        generation_id = None

@@ -101,12 +101,6 @@ class S3Logger:
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            messages = kwargs.get("messages")
            optional_params = kwargs.get("optional_params", {})
            call_type = kwargs.get("call_type", "litellm.completion")
            cache_hit = kwargs.get("cache_hit", False)
            usage = response_obj["usage"]
            id = response_obj.get("id", str(uuid.uuid4()))

            # Clean Metadata before logging - never log raw metadata
            # the raw metadata can contain circular references which leads to infinite recursion

@@ -171,5 +165,5 @@ class S3Logger:
            print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
            return response
        except Exception as e:
            verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
            verbose_logger.exception(f"s3 Layer Error - {str(e)}")
            pass

@@ -41,6 +41,7 @@ from litellm.types.utils import (
    StandardLoggingMetadata,
    StandardLoggingModelInformation,
    StandardLoggingPayload,
    StandardPassThroughResponseObject,
    TextCompletionResponse,
    TranscriptionResponse,
)

@@ -534,7 +535,9 @@ class Logging:
        """
        ## RESPONSE COST ##
        custom_pricing = use_custom_pricing_for_model(
            litellm_params=self.litellm_params
            litellm_params=(
                self.litellm_params if hasattr(self, "litellm_params") else None
            )
        )

        if cache_hit is None:

@@ -611,6 +614,17 @@ class Logging:
                ] = result._hidden_params
                ## STANDARDIZED LOGGING PAYLOAD

                self.model_call_details["standard_logging_object"] = (
                    get_standard_logging_object_payload(
                        kwargs=self.model_call_details,
                        init_response_obj=result,
                        start_time=start_time,
                        end_time=end_time,
                        logging_obj=self,
                    )
                )
            elif isinstance(result, dict):  # pass-through endpoints
                ## STANDARDIZED LOGGING PAYLOAD
                self.model_call_details["standard_logging_object"] = (
                    get_standard_logging_object_payload(
                        kwargs=self.model_call_details,

@@ -2271,6 +2285,8 @@ def get_standard_logging_object_payload(
    elif isinstance(init_response_obj, BaseModel):
        response_obj = init_response_obj.model_dump()
        hidden_params = getattr(init_response_obj, "_hidden_params", None)
    elif isinstance(init_response_obj, dict):
        response_obj = init_response_obj
    else:
        response_obj = {}
    # standardize this function to be used across, s3, dynamoDB, langfuse logging

@@ -674,7 +674,7 @@ async def make_call(
    timeout: Optional[Union[float, httpx.Timeout]],
):
    if client is None:
        client = _get_async_httpx_client()  # Create a new client if none provided
        client = litellm.module_level_aclient

    try:
        response = await client.post(

@@ -690,11 +690,6 @@ async def make_call(
        raise e
        raise AnthropicError(status_code=500, message=str(e))

    if response.status_code != 200:
        raise AnthropicError(
            status_code=response.status_code, message=await response.aread()
        )

    completion_stream = ModelResponseIterator(
        streaming_response=response.aiter_lines(), sync_stream=False
    )

@@ -721,7 +716,7 @@ def make_sync_call(
    timeout: Optional[Union[float, httpx.Timeout]],
):
    if client is None:
        client = HTTPHandler()  # Create a new client if none provided
        client = litellm.module_level_client  # re-use a module level client

    try:
        response = client.post(

@@ -869,6 +864,7 @@ class AnthropicChatCompletion(BaseLLM):
        model_response: ModelResponse,
        print_verbose: Callable,
        timeout: Union[float, httpx.Timeout],
        client: Optional[AsyncHTTPHandler],
        encoding,
        api_key,
        logging_obj,

@@ -882,19 +878,18 @@ class AnthropicChatCompletion(BaseLLM):
    ):
        data["stream"] = True

        completion_stream = await make_call(
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                client=None,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
            ),
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="anthropic",
            logging_obj=logging_obj,

@@ -1080,6 +1075,11 @@ class AnthropicChatCompletion(BaseLLM):
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )
        else:
            return self.acompletion_function(

@@ -1105,33 +1105,32 @@ class AnthropicChatCompletion(BaseLLM):
            )
        else:
            ## COMPLETION CALL
            if client is None or not isinstance(client, HTTPHandler):
                client = HTTPHandler(timeout=timeout)  # type: ignore
            else:
                client = client
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                completion_stream = make_sync_call(
                    client=client,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
                )
                return CustomStreamWrapper(
                    completion_stream=None,
                    make_call=partial(
                        make_sync_call,
                        client=None,
                        api_base=api_base,
                        headers=headers,  # type: ignore
                        data=json.dumps(data),
                        model=model,
                        messages=messages,
                        logging_obj=logging_obj,
                        timeout=timeout,
                    ),
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="anthropic",
                    logging_obj=logging_obj,
                )

            else:
                if client is None or not isinstance(client, HTTPHandler):
                    client = HTTPHandler(timeout=timeout)  # type: ignore
                else:
                    client = client
                response = client.post(
                    api_base, headers=headers, data=json.dumps(data), timeout=timeout
                )

@@ -110,7 +110,7 @@ class BedrockEmbedding(BaseAWSLLM):
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

@@ -37,16 +37,10 @@ class AsyncHTTPHandler:

        # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
        # /path/to/certificate.pem
        ssl_verify = os.getenv(
            "SSL_VERIFY",
            litellm.ssl_verify
        )
        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
        # An SSL certificate used by the requested host to authenticate the client.
        # /path/to/client.pem
        cert = os.getenv(
            "SSL_CERTIFICATE",
            litellm.ssl_certificate
        )
        cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

        if timeout is None:
            timeout = _DEFAULT_TIMEOUT

@@ -277,16 +271,10 @@ class HTTPHandler:

        # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
        # /path/to/certificate.pem
        ssl_verify = os.getenv(
            "SSL_VERIFY",
            litellm.ssl_verify
        )
        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
        # An SSL certificate used by the requested host to authenticate the client.
        # /path/to/client.pem
        cert = os.getenv(
            "SSL_CERTIFICATE",
            litellm.ssl_certificate
        )
        cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

        if client is None:
            # Create a client with a connection pool

@@ -334,6 +322,7 @@ class HTTPHandler:
                "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
            )
            response = self.client.send(req, stream=stream)
            response.raise_for_status()
            return response
        except httpx.TimeoutException:
            raise litellm.Timeout(

@@ -341,6 +330,13 @@ class HTTPHandler:
                model="default-model-name",
                llm_provider="litellm-httpx-handler",
            )
        except httpx.HTTPStatusError as e:
            setattr(e, "status_code", e.response.status_code)
            if stream is True:
                setattr(e, "message", e.response.read())
            else:
                setattr(e, "message", e.response.text)
            raise e
        except Exception as e:
            raise e

@@ -375,7 +371,6 @@ class HTTPHandler:
        except Exception as e:
            raise e


    def __del__(self) -> None:
        try:
            self.close()

@@ -437,4 +432,4 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
        _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

    litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
    return _new_client
    return _new_client

@@ -534,6 +534,15 @@ def mock_completion(
            model=model,  # type: ignore
            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
        )
    elif isinstance(mock_response, str) and mock_response.startswith(
        "Exception: mock_streaming_error"
    ):
        mock_response = litellm.MockException(
            message="This is a mock error raised mid-stream",
            llm_provider="anthropic",
            model=model,
            status_code=529,
        )
    time_delay = kwargs.get("mock_delay", None)
    if time_delay is not None:
        time.sleep(time_delay)

@@ -561,6 +570,8 @@ def mock_completion(
            custom_llm_provider="openai",
            logging_obj=logging,
        )
    if isinstance(mock_response, litellm.MockException):
        raise mock_response
    if n is None:
        model_response.choices[0].message.content = mock_response  # type: ignore
    else:

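The `mock_completion` change above adds a way to simulate a provider error that surfaces while a stream is being consumed. A small sketch of how that hook might be used, assuming the `Exception: mock_streaming_error` sentinel shown in the hunk and that the resulting `litellm.MockException` is raised during iteration:

```python
import litellm

# The "Exception: mock_streaming_error" sentinel makes mock_completion build a
# litellm.MockException (status 529) instead of returning mocked text.
try:
    response = litellm.completion(
        model="anthropic/claude-3-5-sonnet-20240620",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        mock_response="Exception: mock_streaming_error",
    )
    for chunk in response:
        print(chunk)
except litellm.MockException as e:
    # Useful for testing fallback / retry behaviour without a real outage.
    print(f"caught mock mid-stream error: {e}")
```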
@@ -3408,6 +3408,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "amazon.titan-text-premier-v1:0": {
        "max_tokens": 32000,
        "max_input_tokens": 42000,
        "max_output_tokens": 32000,
        "input_cost_per_token": 0.0000005,
        "output_cost_per_token": 0.0000015,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "amazon.titan-embed-text-v1": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,

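For reference, the new `amazon.titan-text-premier-v1:0` entry prices input at $0.0000005/token and output at $0.0000015/token, so a call with 1,000 prompt tokens and 500 completion tokens costs 1000 × 0.0000005 + 500 × 0.0000015 = $0.00125. A hedged sketch of checking that against litellm's cost helper (assuming `litellm.cost_per_token` reads this pricing map):

```python
import litellm

# Cost check against the pricing added above.
prompt_cost, completion_cost = litellm.cost_per_token(
    model="amazon.titan-text-premier-v1:0",
    prompt_tokens=1000,
    completion_tokens=500,
)
# Expected: 1000 * 0.0000005 = 0.0005 and 500 * 0.0000015 = 0.00075
print(prompt_cost, completion_cost)
```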
@@ -1,7 +1,16 @@
model_list:
  - model_name: "*"
  - model_name: "anthropic/claude-3-5-sonnet-20240620"
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20240620
      # api_base: http://0.0.0.0:9000
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: openai/*

litellm_settings:
  default_max_internal_user_budget: 2
  success_callback: ["s3"]
  s3_callback_params:
    s3_bucket_name: litellm-logs  # AWS Bucket Name for S3
    s3_region_name: us-west-2  # AWS Region Name for S3
    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID  # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY  # AWS Secret Access Key for S3

@@ -872,6 +872,17 @@ class TeamMemberDeleteRequest(LiteLLMBase):
        return values


class TeamMemberUpdateRequest(TeamMemberDeleteRequest):
    max_budget_in_team: float


class TeamMemberUpdateResponse(LiteLLMBase):
    team_id: str
    user_id: str
    user_email: Optional[str] = None
    max_budget_in_team: float


class UpdateTeamRequest(LiteLLMBase):
    """
    UpdateTeamRequest, used by /team/update when you need to update a team

@@ -1854,3 +1865,10 @@ class LiteLLM_TeamMembership(LiteLLMBase):
class TeamAddMemberResponse(LiteLLM_TeamTable):
    updated_users: List[LiteLLM_UserTable]
    updated_team_memberships: List[LiteLLM_TeamMembership]


class TeamInfoResponseObject(TypedDict):
    team_id: str
    team_info: TeamBase
    keys: List
    team_memberships: List[LiteLLM_TeamMembership]

@@ -28,8 +28,11 @@ from litellm.proxy._types import (
    ProxyErrorTypes,
    ProxyException,
    TeamAddMemberResponse,
    TeamInfoResponseObject,
    TeamMemberAddRequest,
    TeamMemberDeleteRequest,
    TeamMemberUpdateRequest,
    TeamMemberUpdateResponse,
    UpdateTeamRequest,
    UserAPIKeyAuth,
)

@@ -750,6 +753,131 @@ async def team_member_delete(
    return existing_team_row


@router.post(
    "/team/member_update",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=TeamMemberUpdateResponse,
)
@management_endpoint_wrapper
async def team_member_update(
    data: TeamMemberUpdateRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA]

    Update team member budgets
    """
    from litellm.proxy.proxy_server import (
        _duration_in_seconds,
        create_audit_log_for_update,
        litellm_proxy_admin_name,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})

    if data.user_id is None and data.user_email is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Either user_id or user_email needs to be passed in"},
        )

    _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )

    if _existing_team_row is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Team id={} does not exist in db".format(data.team_id)},
        )
    existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN

    if (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
        and not _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
        )
    ):
        raise HTTPException(
            status_code=403,
            detail={
                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
                    "/team/member_delete", existing_team_row.team_id
                )
            },
        )

    returned_team_info: TeamInfoResponseObject = await team_info(
        http_request=http_request,
        team_id=data.team_id,
        user_api_key_dict=user_api_key_dict,
    )

    ## get user id
    received_user_id: Optional[str] = None
    if data.user_id is not None:
        received_user_id = data.user_id
    elif data.user_email is not None:
        for member in returned_team_info["team_info"].members_with_roles:
            if member.user_email is not None and member.user_email == data.user_email:
                received_user_id = member.user_id
                break

    if received_user_id is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "User id doesn't exist in team table. Data={}".format(data)
            },
        )
    ## find the relevant team membership
    identified_budget_id: Optional[str] = None
    for tm in returned_team_info["team_memberships"]:
        if tm.user_id == received_user_id:
            identified_budget_id = tm.budget_id
            break

    ### upsert new budget
    if identified_budget_id is None:
        new_budget = await prisma_client.db.litellm_budgettable.create(
            data={
                "max_budget": data.max_budget_in_team,
                "created_by": user_api_key_dict.user_id or "",
                "updated_by": user_api_key_dict.user_id or "",
            }
        )

        await prisma_client.db.litellm_teammembership.create(
            data={
                "team_id": data.team_id,
                "user_id": received_user_id,
                "budget_id": new_budget.budget_id,
            },
        )
    else:
        await prisma_client.db.litellm_budgettable.update(
            where={"budget_id": identified_budget_id},
            data={"max_budget": data.max_budget_in_team},
        )

    return TeamMemberUpdateResponse(
        team_id=data.team_id,
        user_id=received_user_id,
        user_email=data.user_email,
        max_budget_in_team=data.max_budget_in_team,
    )


@router.post(
    "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)

@@ -937,12 +1065,18 @@ async def team_info(
            where={"team_id": team_id},
            include={"litellm_budget_table": True},
        )
        return {
            "team_id": team_id,
            "team_info": team_info,
            "keys": keys,
            "team_memberships": team_memberships,
        }

        returned_tm: List[LiteLLM_TeamMembership] = []
        for tm in team_memberships:
            returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))

        response_object = TeamInfoResponseObject(
            team_id=team_id,
            team_info=team_info,
            keys=keys,
            team_memberships=returned_tm,
        )
        return response_object

    except Exception as e:
        verbose_proxy_logger.error(

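The new `/team/member_update` route above accepts a `TeamMemberUpdateRequest` (a `team_id`, either `user_id` or `user_email`, and `max_budget_in_team`) and upserts that member's budget row. A hedged sketch of calling it against a running proxy; the base URL and admin key below are placeholders:

```python
import requests

# Placeholders: point these at your litellm proxy and an admin / team-admin key.
PROXY_BASE_URL = "http://localhost:4000"
ADMIN_KEY = "sk-1234"

resp = requests.post(
    f"{PROXY_BASE_URL}/team/member_update",
    headers={"Authorization": f"Bearer {ADMIN_KEY}"},
    json={
        "team_id": "my-team-id",
        "user_id": "my-user-id",
        "max_budget_in_team": 10.0,  # USD budget for this member within the team
    },
)
print(resp.status_code, resp.json())
```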
@@ -359,7 +359,7 @@ async def pass_through_request(
    start_time = datetime.now()
    logging_obj = Logging(
        model="unknown",
        messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
        messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=start_time,

@@ -414,7 +414,7 @@ async def pass_through_request(
    logging_url = str(url) + "?" + requested_query_params_str

    logging_obj.pre_call(
        input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
        input=[{"role": "user", "content": json.dumps(_parsed_body)}],
        api_key="",
        additional_args={
            "complete_input_dict": _parsed_body,

@@ -1,4 +1,6 @@
import json
import re
import threading
from datetime import datetime
from typing import Union

@@ -10,6 +12,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
    VertexLLM,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject


class PassThroughEndpointLogging:

@@ -43,8 +46,24 @@ class PassThroughEndpointLogging:
                **kwargs,
            )
        else:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=httpx_response.text
            )
            threading.Thread(
                target=logging_obj.success_handler,
                args=(
                    standard_logging_response_object,
                    start_time,
                    end_time,
                    cache_hit,
                ),
            ).start()
            await logging_obj.async_success_handler(
                result="",
                result=(
                    json.dumps(result)
                    if isinstance(result, dict)
                    else standard_logging_response_object
                ),
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,

@@ -3005,13 +3005,13 @@ def model_list(

    This is just for compatibility with openai projects like aider.
    """
    global llm_model_list, general_settings
    global llm_model_list, general_settings, llm_router
    all_models = []
    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
    if llm_model_list is None:
    if llm_router is None:
        proxy_model_list = []
    else:
        proxy_model_list = [m["model_name"] for m in llm_model_list]
        proxy_model_list = llm_router.get_model_names()
    key_models = get_key_models(
        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
    )

@@ -7503,10 +7503,11 @@ async def model_info_v1(

    all_models: List[dict] = []
    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
    if llm_model_list is None:
    if llm_router is None:
        proxy_model_list = []
    else:
        proxy_model_list = [m["model_name"] for m in llm_model_list]
        proxy_model_list = llm_router.get_model_names()

    key_models = get_key_models(
        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
    )

@@ -7523,8 +7524,14 @@ async def model_info_v1(

    if len(all_models_str) > 0:
        model_names = all_models_str
        _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
        all_models = copy.deepcopy(_relevant_models)
        llm_model_list = llm_router.get_model_list()
        if llm_model_list is not None:
            _relevant_models = [
                m for m in llm_model_list if m["model_name"] in model_names
            ]
            all_models = copy.deepcopy(_relevant_models)  # type: ignore
        else:
            all_models = []

    for model in all_models:
        # provided model_info in config.yaml

@@ -7590,12 +7597,12 @@ async def model_group_info(
        raise HTTPException(
            status_code=500, detail={"error": "LLM Router is not loaded in"}
        )
    all_models: List[dict] = []
    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
    if llm_model_list is None:
    if llm_router is None:
        proxy_model_list = []
    else:
        proxy_model_list = [m["model_name"] for m in llm_model_list]
        proxy_model_list = llm_router.get_model_names()

    key_models = get_key_models(
        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
    )

@@ -86,6 +86,7 @@ from litellm.types.router import (
    Deployment,
    DeploymentTypedDict,
    LiteLLM_Params,
    LiteLLMParamsTypedDict,
    ModelGroupInfo,
    ModelInfo,
    RetryPolicy,

@@ -4297,7 +4298,9 @@ class Router:
                return model
        return None

    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
    def _set_model_group_info(
        self, model_group: str, user_facing_model_group_name: str
    ) -> Optional[ModelGroupInfo]:
        """
        For a given model group name, return the combined model info

@@ -4379,7 +4382,7 @@ class Router:

        if model_group_info is None:
            model_group_info = ModelGroupInfo(
                model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
                model_group=user_facing_model_group_name, providers=[llm_provider], **model_info  # type: ignore
            )
        else:
            # if max_input_tokens > curr

@@ -4464,6 +4467,26 @@ class Router:

        return model_group_info

    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
        """
        For a given model group name, return the combined model info

        Returns:
        - ModelGroupInfo if able to construct a model group
        - None if error constructing model group info
        """
        ## Check if model group alias
        if model_group in self.model_group_alias:
            return self._set_model_group_info(
                model_group=self.model_group_alias[model_group],
                user_facing_model_group_name=model_group,
            )

        ## Check if actual model
        return self._set_model_group_info(
            model_group=model_group, user_facing_model_group_name=model_group
        )

    async def get_model_group_usage(
        self, model_group: str
    ) -> Tuple[Optional[int], Optional[int]]:

@@ -4534,19 +4557,35 @@ class Router:
        return ids

    def get_model_names(self) -> List[str]:
        return self.model_names
        """
        Returns all possible model names for router.

        Includes model_group_alias models too.
        """
        return self.model_names + list(self.model_group_alias.keys())

    def get_model_list(
        self, model_name: Optional[str] = None
    ) -> Optional[List[DeploymentTypedDict]]:
        if hasattr(self, "model_list"):
            if model_name is None:
                return self.model_list

            returned_models: List[DeploymentTypedDict] = []

            for model_alias, model_value in self.model_group_alias.items():
                model_alias_item = DeploymentTypedDict(
                    model_name=model_alias,
                    litellm_params=LiteLLMParamsTypedDict(model=model_value),
                )
                returned_models.append(model_alias_item)

            if model_name is None:
                returned_models += self.model_list

                return returned_models

            for model in self.model_list:
                if model["model_name"] == model_name:
                    returned_models.append(model)

            return returned_models
        return None

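The router changes above make `model_group_alias` entries visible wherever the proxy lists models, since `get_model_names()` and `get_model_list()` now fold the aliases in. A minimal sketch (the alias and deployment names are illustrative):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    # Callers can use "gpt-4o-alias" and be routed to the "gpt-3.5-turbo" group.
    model_group_alias={"gpt-4o-alias": "gpt-3.5-turbo"},
)

# With this commit the alias shows up alongside the real model group,
# which is what /v1/models, /model/info and /model_group/info build on.
print(router.get_model_names())  # ["gpt-3.5-turbo", "gpt-4o-alias"]
print(router.get_model_list())   # includes a synthesized entry for the alias
```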
@@ -22,7 +22,7 @@ import litellm
from litellm import create_batch, create_file


@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("provider", ["openai"])  # , "azure"
def test_create_batch(provider):
    """
    1. Create File for Batch completion

@@ -96,7 +96,7 @@ def test_create_batch(provider):
    pass


@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("provider", ["openai"])  # "azure"
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_async_create_batch(provider):

@@ -2396,6 +2396,7 @@ async def test_router_weighted_pick(sync_mode):
    assert model_id_1_count > model_id_2_count


@pytest.mark.skip(reason="Hit azure batch quota limits")
@pytest.mark.parametrize("provider", ["azure"])
@pytest.mark.asyncio
async def test_router_batch_endpoints(provider):

@@ -12,6 +12,7 @@ import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch

import litellm
from litellm import Router

@@ -1266,3 +1267,73 @@ async def test_using_default_working_fallback(sync_mode):
    )
    print("got response=", response)
    assert response is not None


# asyncio.run(test_acompletion_gemini_stream())
def mock_post_streaming(url, **kwargs):
    mock_response = MagicMock()
    mock_response.status_code = 529
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.return_value = {"detail": "Overloaded!"}

    return mock_response


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_anthropic_streaming_fallbacks(sync_mode):
    litellm.set_verbose = True
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

    if sync_mode:
        client = HTTPHandler(concurrent_limit=1)
    else:
        client = AsyncHTTPHandler(concurrent_limit=1)

    router = Router(
        model_list=[
            {
                "model_name": "anthropic/claude-3-5-sonnet-20240620",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                },
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "Hey, how's it going?",
                },
            },
        ],
        fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
        num_retries=0,
    )

    with patch.object(client, "post", side_effect=mock_post_streaming) as mock_client:
        chunks = []
        if sync_mode:
            response = router.completion(
                model="anthropic/claude-3-5-sonnet-20240620",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                stream=True,
                client=client,
            )
            for chunk in response:
                print(chunk)
                chunks.append(chunk)
        else:
            response = await router.acompletion(
                model="anthropic/claude-3-5-sonnet-20240620",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                stream=True,
                client=client,
            )
            async for chunk in response:
                print(chunk)
                chunks.append(chunk)
        print(f"RETURNED response: {response}")

        mock_client.assert_called_once()
        print(chunks)
        assert len(chunks) > 0

@@ -127,6 +127,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
        "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
    ],
)
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_sagemaker_stream(sync_mode, model):
    try:
        from litellm.tests.test_streaming import streaming_format_tests

@@ -3144,7 +3144,6 @@ async def test_azure_astreaming_and_function_calling():


def test_completion_claude_3_function_call_with_streaming():
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",

@@ -3827,6 +3826,65 @@ def test_unit_test_custom_stream_wrapper_function_call():
    assert len(new_model.choices[0].delta.tool_calls) > 0


def test_unit_test_perplexity_citations_chunk():
    """
    Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
    """
    from litellm.types.llms.openai import ChatCompletionDeltaChunk

    litellm.set_verbose = False
    delta: ChatCompletionDeltaChunk = {
        "content": "B",
        "role": "assistant",
    }
    chunk = {
        "id": "xxx",
        "model": "llama-3.1-sonar-small-128k-online",
        "created": 1725494279,
        "usage": {"prompt_tokens": 15, "completion_tokens": 1, "total_tokens": 16},
        "citations": [
            "https://x.com/bizzabo?lang=ur",
            "https://apps.apple.com/my/app/bizzabo/id408705047",
            "https://www.bizzabo.com/blog/maximize-event-data-strategies-for-success",
        ],
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "finish_reason": None,
                "message": {"role": "assistant", "content": "B"},
                "delta": delta,
            }
        ],
    }
    chunk = litellm.ModelResponse(**chunk, stream=True)

    completion_stream = ModelResponseIterator(model_response=chunk)

    response = litellm.CustomStreamWrapper(
        completion_stream=completion_stream,
        model="gpt-3.5-turbo",
        custom_llm_provider="cached_response",
        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey"}],
            stream=True,
            call_type="completion",
            start_time=time.time(),
            litellm_call_id="12345",
            function_id="1245",
        ),
    )

    finish_reason: Optional[str] = None
    for response_chunk in response:
        if response_chunk.choices[0].delta.content is not None:
            print(
                f"response_chunk.choices[0].delta.content: {response_chunk.choices[0].delta.content}"
            )
            assert "citations" in response_chunk


@pytest.mark.parametrize(
    "model",
    [

@@ -1297,3 +1297,7 @@ class CustomStreamingDecoder:
        self, iterator: Iterator[bytes]
    ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
        raise NotImplementedError


class StandardPassThroughResponseObject(TypedDict):
    response: str

@@ -9845,6 +9845,9 @@ class CustomStreamWrapper:
                        model_response.system_fingerprint = (
                            original_chunk.system_fingerprint
                        )
                        model_response.citations = getattr(
                            original_chunk, "citations", None
                        )
                        print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
                        if self.sent_first_chunk is False:
                            model_response.choices[0].delta["role"] = "assistant"

@@ -10460,6 +10463,8 @@ class TextCompletionStreamWrapper:
def mock_completion_streaming_obj(
    model_response, mock_response, model, n: Optional[int] = None
):
    if isinstance(mock_response, litellm.MockException):
        raise mock_response
    for i in range(0, len(mock_response), 3):
        completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
        if n is None:

@@ -10481,6 +10486,8 @@ def mock_completion_streaming_obj(
async def async_mock_completion_streaming_obj(
    model_response, mock_response, model, n: Optional[int] = None
):
    if isinstance(mock_response, litellm.MockException):
        raise mock_response
    for i in range(0, len(mock_response), 3):
        completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
        if n is None:

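The `CustomStreamWrapper` change above copies a `citations` field from Perplexity chunks onto each streamed `ModelResponse`, which is what the new `test_unit_test_perplexity_citations_chunk` test exercises. A hedged usage sketch (requires a Perplexity API key; the model name and question follow the test data above):

```python
import litellm

# Stream a Perplexity "online" model and read the citations that this commit
# now forwards on each chunk (None when the provider sends no citations).
response = litellm.completion(
    model="perplexity/llama-3.1-sonar-small-128k-online",
    messages=[{"role": "user", "content": "Who founded Bizzabo?"}],
    stream=True,
)
for chunk in response:
    citations = getattr(chunk, "citations", None)
    if citations:
        print("citations:", citations)
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```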
@@ -3408,6 +3408,15 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "amazon.titan-text-premier-v1:0": {
        "max_tokens": 32000,
        "max_input_tokens": 42000,
        "max_output_tokens": 32000,
        "input_cost_per_token": 0.0000005,
        "output_cost_per_token": 0.0000015,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "amazon.titan-embed-text-v1": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,