Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
LiteLLM Minor Fixes and Improvements (08/06/2024) (#5567)
* fix(utils.py): return citations for perplexity streaming. Fixes https://github.com/BerriAI/litellm/issues/5535
* fix(anthropic/chat.py): support fallbacks for anthropic streaming (#5542)
* fix(anthropic/chat.py): support fallbacks for anthropic streaming. Fixes https://github.com/BerriAI/litellm/issues/5512
* fix(anthropic/chat.py): use module level http client if none given (prevents early client closure)
* fix: fix linting errors
* fix(http_handler.py): fix raise_for_status error handling
* test: retry flaky test
* fix otel type
* fix(bedrock/embed): fix error raising
* test(test_openai_batches_and_files.py): skip azure batches test (for now), quota exceeded
* fix(test_router.py): skip azure batch route test (for now), hit batch quota limits
---------
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
* All `model_group_alias` entries should show up in `/models`, `/model/info`, `/model_group/info` (#5539)
* fix(router.py): support returning model_alias model names in `/v1/models`
* fix(proxy_server.py): support returning model aliases on `/model/info`
* feat(router.py): support returning model group alias for `/model_group/info`
* fix(proxy_server.py): fix linting errors
* fix(proxy_server.py): fix linting errors
* build(model_prices_and_context_window.json): add amazon titan text premier pricing information. Closes https://github.com/BerriAI/litellm/issues/5560
* feat(litellm_logging.py): log standard logging response object for pass-through endpoints. Allows bedrock /invoke agent calls to be correctly logged to langfuse + s3
* fix(success_handler.py): fix linting error
* fix(success_handler.py): fix linting errors
* fix(team_endpoints.py): allow admin to update team member budgets
---------
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
This commit is contained in:
parent e4dcd6f745
commit 72e961af3c

25 changed files with 509 additions and 99 deletions
@@ -22,7 +22,8 @@
 "ms-python.python",
 "ms-python.vscode-pylance",
 "GitHub.copilot",
-"GitHub.copilot-chat"
+"GitHub.copilot-chat",
+"ms-python.autopep8"
 ]
 }
 },

@@ -1,12 +1,12 @@
 repos:
 - repo: local
 hooks:
-# - id: mypy
-# name: mypy
-# entry: python3 -m mypy --ignore-missing-imports
-# language: system
-# types: [python]
-# files: ^litellm/
+- id: mypy
+name: mypy
+entry: python3 -m mypy --ignore-missing-imports
+language: system
+types: [python]
+files: ^litellm/
 - id: isort
 name: isort
 entry: isort

@@ -208,6 +208,14 @@ class LangFuseLogger:
 ):
 input = prompt
 output = response_obj["text"]
+elif (
+kwargs.get("call_type") is not None
+and kwargs.get("call_type") == "pass_through_endpoint"
+and response_obj is not None
+and isinstance(response_obj, dict)
+):
+input = prompt
+output = response_obj.get("response", "")
 print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
 trace_id = None
 generation_id = None

@@ -101,12 +101,6 @@ class S3Logger:
 metadata = (
 litellm_params.get("metadata", {}) or {}
 ) # if litellm_params['metadata'] == None
-messages = kwargs.get("messages")
-optional_params = kwargs.get("optional_params", {})
-call_type = kwargs.get("call_type", "litellm.completion")
-cache_hit = kwargs.get("cache_hit", False)
-usage = response_obj["usage"]
-id = response_obj.get("id", str(uuid.uuid4()))
 
 # Clean Metadata before logging - never log raw metadata
 # the raw metadata can contain circular references which leads to infinite recursion

@@ -171,5 +165,5 @@ class S3Logger:
 print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
 return response
 except Exception as e:
-verbose_logger.debug(f"s3 Layer Error - {str(e)}\n{traceback.format_exc()}")
+verbose_logger.exception(f"s3 Layer Error - {str(e)}")
 pass

@@ -41,6 +41,7 @@ from litellm.types.utils import (
 StandardLoggingMetadata,
 StandardLoggingModelInformation,
 StandardLoggingPayload,
+StandardPassThroughResponseObject,
 TextCompletionResponse,
 TranscriptionResponse,
 )

@@ -534,7 +535,9 @@ class Logging:
 """
 ## RESPONSE COST ##
 custom_pricing = use_custom_pricing_for_model(
-litellm_params=self.litellm_params
+litellm_params=(
+self.litellm_params if hasattr(self, "litellm_params") else None
+)
 )
 
 if cache_hit is None:

@@ -611,6 +614,17 @@ class Logging:
 ] = result._hidden_params
 ## STANDARDIZED LOGGING PAYLOAD
 
+self.model_call_details["standard_logging_object"] = (
+get_standard_logging_object_payload(
+kwargs=self.model_call_details,
+init_response_obj=result,
+start_time=start_time,
+end_time=end_time,
+logging_obj=self,
+)
+)
+elif isinstance(result, dict): # pass-through endpoints
+## STANDARDIZED LOGGING PAYLOAD
 self.model_call_details["standard_logging_object"] = (
 get_standard_logging_object_payload(
 kwargs=self.model_call_details,

@@ -2271,6 +2285,8 @@ def get_standard_logging_object_payload(
 elif isinstance(init_response_obj, BaseModel):
 response_obj = init_response_obj.model_dump()
 hidden_params = getattr(init_response_obj, "_hidden_params", None)
+elif isinstance(init_response_obj, dict):
+response_obj = init_response_obj
 else:
 response_obj = {}
 # standardize this function to be used across, s3, dynamoDB, langfuse logging

@@ -674,7 +674,7 @@ async def make_call(
 timeout: Optional[Union[float, httpx.Timeout]],
 ):
 if client is None:
-client = _get_async_httpx_client() # Create a new client if none provided
+client = litellm.module_level_aclient
 
 try:
 response = await client.post(

@@ -690,11 +690,6 @@ async def make_call(
 raise e
 raise AnthropicError(status_code=500, message=str(e))
 
-if response.status_code != 200:
-raise AnthropicError(
-status_code=response.status_code, message=await response.aread()
-)
-
 completion_stream = ModelResponseIterator(
 streaming_response=response.aiter_lines(), sync_stream=False
 )

@@ -721,7 +716,7 @@ def make_sync_call(
 timeout: Optional[Union[float, httpx.Timeout]],
 ):
 if client is None:
-client = HTTPHandler() # Create a new client if none provided
+client = litellm.module_level_client # re-use a module level client
 
 try:
 response = client.post(

@@ -869,6 +864,7 @@ class AnthropicChatCompletion(BaseLLM):
 model_response: ModelResponse,
 print_verbose: Callable,
 timeout: Union[float, httpx.Timeout],
+client: Optional[AsyncHTTPHandler],
 encoding,
 api_key,
 logging_obj,

@@ -882,19 +878,18 @@ class AnthropicChatCompletion(BaseLLM):
 ):
 data["stream"] = True
 
+completion_stream = await make_call(
+client=client,
+api_base=api_base,
+headers=headers,
+data=json.dumps(data),
+model=model,
+messages=messages,
+logging_obj=logging_obj,
+timeout=timeout,
+)
 streamwrapper = CustomStreamWrapper(
-completion_stream=None,
-make_call=partial(
-make_call,
-client=None,
-api_base=api_base,
-headers=headers,
-data=json.dumps(data),
-model=model,
-messages=messages,
-logging_obj=logging_obj,
-timeout=timeout,
-),
+completion_stream=completion_stream,
 model=model,
 custom_llm_provider="anthropic",
 logging_obj=logging_obj,

@@ -1080,6 +1075,11 @@ class AnthropicChatCompletion(BaseLLM):
 logger_fn=logger_fn,
 headers=headers,
 timeout=timeout,
+client=(
+client
+if client is not None and isinstance(client, AsyncHTTPHandler)
+else None
+),
 )
 else:
 return self.acompletion_function(

@@ -1105,33 +1105,32 @@ class AnthropicChatCompletion(BaseLLM):
 )
 else:
 ## COMPLETION CALL
-if client is None or not isinstance(client, HTTPHandler):
-client = HTTPHandler(timeout=timeout) # type: ignore
-else:
-client = client
 if (
 stream is True
 ): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
 data["stream"] = stream
+completion_stream = make_sync_call(
+client=client,
+api_base=api_base,
+headers=headers, # type: ignore
+data=json.dumps(data),
+model=model,
+messages=messages,
+logging_obj=logging_obj,
+timeout=timeout,
+)
 return CustomStreamWrapper(
-completion_stream=None,
-make_call=partial(
-make_sync_call,
-client=None,
-api_base=api_base,
-headers=headers, # type: ignore
-data=json.dumps(data),
-model=model,
-messages=messages,
-logging_obj=logging_obj,
-timeout=timeout,
-),
+completion_stream=completion_stream,
 model=model,
 custom_llm_provider="anthropic",
 logging_obj=logging_obj,
 )
 
 else:
+if client is None or not isinstance(client, HTTPHandler):
+client = HTTPHandler(timeout=timeout) # type: ignore
+else:
+client = client
 response = client.post(
 api_base, headers=headers, data=json.dumps(data), timeout=timeout
 )

@@ -110,7 +110,7 @@ class BedrockEmbedding(BaseAWSLLM):
 response.raise_for_status()
 except httpx.HTTPStatusError as err:
 error_code = err.response.status_code
-raise BedrockError(status_code=error_code, message=response.text)
+raise BedrockError(status_code=error_code, message=err.response.text)
 except httpx.TimeoutException:
 raise BedrockError(status_code=408, message="Timeout error occurred.")
 

@@ -37,16 +37,10 @@ class AsyncHTTPHandler:
 
 # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
 # /path/to/certificate.pem
-ssl_verify = os.getenv(
-"SSL_VERIFY",
-litellm.ssl_verify
-)
+ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
 # An SSL certificate used by the requested host to authenticate the client.
 # /path/to/client.pem
-cert = os.getenv(
-"SSL_CERTIFICATE",
-litellm.ssl_certificate
-)
+cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
 
 if timeout is None:
 timeout = _DEFAULT_TIMEOUT

@@ -277,16 +271,10 @@ class HTTPHandler:
 
 # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
 # /path/to/certificate.pem
-ssl_verify = os.getenv(
-"SSL_VERIFY",
-litellm.ssl_verify
-)
+ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
 # An SSL certificate used by the requested host to authenticate the client.
 # /path/to/client.pem
-cert = os.getenv(
-"SSL_CERTIFICATE",
-litellm.ssl_certificate
-)
+cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
 
 if client is None:
 # Create a client with a connection pool

@@ -334,6 +322,7 @@ class HTTPHandler:
 "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
 )
 response = self.client.send(req, stream=stream)
+response.raise_for_status()
 return response
 except httpx.TimeoutException:
 raise litellm.Timeout(

@@ -341,6 +330,13 @@ class HTTPHandler:
 model="default-model-name",
 llm_provider="litellm-httpx-handler",
 )
+except httpx.HTTPStatusError as e:
+setattr(e, "status_code", e.response.status_code)
+if stream is True:
+setattr(e, "message", e.response.read())
+else:
+setattr(e, "message", e.response.text)
+raise e
 except Exception as e:
 raise e
 

@@ -375,7 +371,6 @@ class HTTPHandler:
 except Exception as e:
 raise e
 
-
 def __del__(self) -> None:
 try:
 self.close()

@@ -437,4 +432,4 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
 _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
 
 litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
 return _new_client

@@ -534,6 +534,15 @@ def mock_completion(
 model=model, # type: ignore
 request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
 )
+elif isinstance(mock_response, str) and mock_response.startswith(
+"Exception: mock_streaming_error"
+):
+mock_response = litellm.MockException(
+message="This is a mock error raised mid-stream",
+llm_provider="anthropic",
+model=model,
+status_code=529,
+)
 time_delay = kwargs.get("mock_delay", None)
 if time_delay is not None:
 time.sleep(time_delay)

@@ -561,6 +570,8 @@ def mock_completion(
 custom_llm_provider="openai",
 logging_obj=logging,
 )
+if isinstance(mock_response, litellm.MockException):
+raise mock_response
 if n is None:
 model_response.choices[0].message.content = mock_response # type: ignore
 else:

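The mock_completion hunks above add a "mock_streaming_error" escape hatch: a string mock response beginning with "Exception: mock_streaming_error" is converted into a litellm.MockException (status 529) and, per the streaming-mock changes later in this diff, raised once the mocked stream starts. A minimal sketch of exercising it; the model name is illustrative, and the exact exception type surfaced to the caller may be wrapped by litellm's exception mapping:

import litellm

try:
    response = litellm.completion(
        model="anthropic/claude-3-5-sonnet-20240620",  # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
        mock_response="Exception: mock_streaming_error",
        stream=True,
    )
    for chunk in response:
        print(chunk)
except Exception as e:
    # expected: the mocked mid-stream error (status 529) surfaces here
    print("caught mocked streaming error:", e)
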
@@ -3408,6 +3408,15 @@
 "litellm_provider": "bedrock",
 "mode": "chat"
 },
+"amazon.titan-text-premier-v1:0": {
+"max_tokens": 32000,
+"max_input_tokens": 42000,
+"max_output_tokens": 32000,
+"input_cost_per_token": 0.0000005,
+"output_cost_per_token": 0.0000015,
+"litellm_provider": "bedrock",
+"mode": "chat"
+},
 "amazon.titan-embed-text-v1": {
 "max_tokens": 8192,
 "max_input_tokens": 8192,

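For reference, the new amazon.titan-text-premier-v1:0 entry prices input tokens at 0.0000005 USD and output tokens at 0.0000015 USD. A quick back-of-the-envelope check of what a call would cost (plain arithmetic, not a litellm API call):

# Rough cost estimate for one bedrock amazon.titan-text-premier-v1:0 call,
# using the per-token prices added in this commit.
input_tokens, output_tokens = 1_000, 500
cost = input_tokens * 0.0000005 + output_tokens * 0.0000015
print(f"estimated cost: ${cost:.6f}")  # $0.001250
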
@@ -1,7 +1,16 @@
 model_list:
-- model_name: "*"
+- model_name: "anthropic/claude-3-5-sonnet-20240620"
+litellm_params:
+model: anthropic/claude-3-5-sonnet-20240620
+# api_base: http://0.0.0.0:9000
+- model_name: gpt-3.5-turbo
 litellm_params:
 model: openai/*
 
 litellm_settings:
-default_max_internal_user_budget: 2
+success_callback: ["s3"]
+s3_callback_params:
+s3_bucket_name: litellm-logs # AWS Bucket Name for S3
+s3_region_name: us-west-2 # AWS Region Name for S3
+s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

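The updated proxy config above turns on the s3 success callback. A hedged sketch of the same setup through the Python SDK, assuming (as is usual for litellm_settings keys) that success_callback and s3_callback_params map to module-level attributes; the bucket, region, and env var names are placeholders:

import os
import litellm

# Assumed SDK-side equivalent of the litellm_settings block in the config above.
litellm.success_callback = ["s3"]
litellm.s3_callback_params = {
    "s3_bucket_name": "litellm-logs",  # placeholder bucket
    "s3_region_name": "us-west-2",     # placeholder region
    "s3_aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
    "s3_aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    mock_response="hi there",  # mocked so the sketch runs without a provider key
)
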
@@ -872,6 +872,17 @@ class TeamMemberDeleteRequest(LiteLLMBase):
 return values
 
 
+class TeamMemberUpdateRequest(TeamMemberDeleteRequest):
+max_budget_in_team: float
+
+
+class TeamMemberUpdateResponse(LiteLLMBase):
+team_id: str
+user_id: str
+user_email: Optional[str] = None
+max_budget_in_team: float
+
+
 class UpdateTeamRequest(LiteLLMBase):
 """
 UpdateTeamRequest, used by /team/update when you need to update a team

@@ -1854,3 +1865,10 @@ class LiteLLM_TeamMembership(LiteLLMBase):
 class TeamAddMemberResponse(LiteLLM_TeamTable):
 updated_users: List[LiteLLM_UserTable]
 updated_team_memberships: List[LiteLLM_TeamMembership]
+
+
+class TeamInfoResponseObject(TypedDict):
+team_id: str
+team_info: TeamBase
+keys: List
+team_memberships: List[LiteLLM_TeamMembership]

@@ -28,8 +28,11 @@ from litellm.proxy._types import (
 ProxyErrorTypes,
 ProxyException,
 TeamAddMemberResponse,
+TeamInfoResponseObject,
 TeamMemberAddRequest,
 TeamMemberDeleteRequest,
+TeamMemberUpdateRequest,
+TeamMemberUpdateResponse,
 UpdateTeamRequest,
 UserAPIKeyAuth,
 )

@@ -750,6 +753,131 @@ async def team_member_delete(
 return existing_team_row
 
 
+@router.post(
+"/team/member_update",
+tags=["team management"],
+dependencies=[Depends(user_api_key_auth)],
+response_model=TeamMemberUpdateResponse,
+)
+@management_endpoint_wrapper
+async def team_member_update(
+data: TeamMemberUpdateRequest,
+http_request: Request,
+user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+"""
+[BETA]
+
+Update team member budgets
+"""
+from litellm.proxy.proxy_server import (
+_duration_in_seconds,
+create_audit_log_for_update,
+litellm_proxy_admin_name,
+prisma_client,
+)
+
+if prisma_client is None:
+raise HTTPException(status_code=500, detail={"error": "No db connected"})
+
+if data.team_id is None:
+raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
+
+if data.user_id is None and data.user_email is None:
+raise HTTPException(
+status_code=400,
+detail={"error": "Either user_id or user_email needs to be passed in"},
+)
+
+_existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
+where={"team_id": data.team_id}
+)
+
+if _existing_team_row is None:
+raise HTTPException(
+status_code=400,
+detail={"error": "Team id={} does not exist in db".format(data.team_id)},
+)
+existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
+
+## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+if (
+user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+and not _is_user_team_admin(
+user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
+)
+):
+raise HTTPException(
+status_code=403,
+detail={
+"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+"/team/member_delete", existing_team_row.team_id
+)
+},
+)
+
+returned_team_info: TeamInfoResponseObject = await team_info(
+http_request=http_request,
+team_id=data.team_id,
+user_api_key_dict=user_api_key_dict,
+)
+
+## get user id
+received_user_id: Optional[str] = None
+if data.user_id is not None:
+received_user_id = data.user_id
+elif data.user_email is not None:
+for member in returned_team_info["team_info"].members_with_roles:
+if member.user_email is not None and member.user_email == data.user_email:
+received_user_id = member.user_id
+break
+
+if received_user_id is None:
+raise HTTPException(
+status_code=400,
+detail={
+"error": "User id doesn't exist in team table. Data={}".format(data)
+},
+)
+## find the relevant team membership
+identified_budget_id: Optional[str] = None
+for tm in returned_team_info["team_memberships"]:
+if tm.user_id == received_user_id:
+identified_budget_id = tm.budget_id
+break
+
+### upsert new budget
+if identified_budget_id is None:
+new_budget = await prisma_client.db.litellm_budgettable.create(
+data={
+"max_budget": data.max_budget_in_team,
+"created_by": user_api_key_dict.user_id or "",
+"updated_by": user_api_key_dict.user_id or "",
+}
+)
+
+await prisma_client.db.litellm_teammembership.create(
+data={
+"team_id": data.team_id,
+"user_id": received_user_id,
+"budget_id": new_budget.budget_id,
+},
+)
+else:
+await prisma_client.db.litellm_budgettable.update(
+where={"budget_id": identified_budget_id},
+data={"max_budget": data.max_budget_in_team},
+)
+
+return TeamMemberUpdateResponse(
+team_id=data.team_id,
+user_id=received_user_id,
+user_email=data.user_email,
+max_budget_in_team=data.max_budget_in_team,
+)
+
+
 @router.post(
 "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
 )

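The new /team/member_update route takes a TeamMemberUpdateRequest (team_id plus user_id or user_email, and max_budget_in_team) and returns a TeamMemberUpdateResponse. A hedged sketch of calling it against a running proxy; the base URL, key, and ids are placeholders:

import httpx

PROXY_BASE_URL = "http://0.0.0.0:4000"  # placeholder proxy address
ADMIN_KEY = "sk-1234"                   # placeholder proxy admin / team admin key

# Set a $10 in-team budget for one member of an existing team.
resp = httpx.post(
    f"{PROXY_BASE_URL}/team/member_update",
    headers={"Authorization": f"Bearer {ADMIN_KEY}"},
    json={
        "team_id": "my-team-id",  # placeholder team id
        "user_id": "my-user-id",  # or pass "user_email" instead
        "max_budget_in_team": 10.0,
    },
)
resp.raise_for_status()
print(resp.json())  # team_id, user_id, user_email, max_budget_in_team
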
@@ -937,12 +1065,18 @@ async def team_info(
 where={"team_id": team_id},
 include={"litellm_budget_table": True},
 )
-return {
-"team_id": team_id,
-"team_info": team_info,
-"keys": keys,
-"team_memberships": team_memberships,
-}
+returned_tm: List[LiteLLM_TeamMembership] = []
+for tm in team_memberships:
+returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
+
+response_object = TeamInfoResponseObject(
+team_id=team_id,
+team_info=team_info,
+keys=keys,
+team_memberships=returned_tm,
+)
+return response_object
+
 except Exception as e:
 verbose_proxy_logger.error(

@@ -359,7 +359,7 @@ async def pass_through_request(
 start_time = datetime.now()
 logging_obj = Logging(
 model="unknown",
-messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
 stream=False,
 call_type="pass_through_endpoint",
 start_time=start_time,

@@ -414,7 +414,7 @@ async def pass_through_request(
 logging_url = str(url) + "?" + requested_query_params_str
 
 logging_obj.pre_call(
-input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+input=[{"role": "user", "content": json.dumps(_parsed_body)}],
 api_key="",
 additional_args={
 "complete_input_dict": _parsed_body,

@@ -1,4 +1,6 @@
+import json
 import re
+import threading
 from datetime import datetime
 from typing import Union
 

@@ -10,6 +12,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 VertexLLM,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.types.utils import StandardPassThroughResponseObject
 
 
 class PassThroughEndpointLogging:

@@ -43,8 +46,24 @@ class PassThroughEndpointLogging:
 **kwargs,
 )
 else:
+standard_logging_response_object = StandardPassThroughResponseObject(
+response=httpx_response.text
+)
+threading.Thread(
+target=logging_obj.success_handler,
+args=(
+standard_logging_response_object,
+start_time,
+end_time,
+cache_hit,
+),
+).start()
 await logging_obj.async_success_handler(
-result="",
+result=(
+json.dumps(result)
+if isinstance(result, dict)
+else standard_logging_response_object
+),
 start_time=start_time,
 end_time=end_time,
 cache_hit=False,

@@ -3005,13 +3005,13 @@ def model_list(
 
 This is just for compatibility with openai projects like aider.
 """
-global llm_model_list, general_settings
+global llm_model_list, general_settings, llm_router
 all_models = []
 ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-if llm_model_list is None:
+if llm_router is None:
 proxy_model_list = []
 else:
-proxy_model_list = [m["model_name"] for m in llm_model_list]
+proxy_model_list = llm_router.get_model_names()
 key_models = get_key_models(
 user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
 )

@@ -7503,10 +7503,11 @@ async def model_info_v1(
 
 all_models: List[dict] = []
 ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-if llm_model_list is None:
+if llm_router is None:
 proxy_model_list = []
 else:
-proxy_model_list = [m["model_name"] for m in llm_model_list]
+proxy_model_list = llm_router.get_model_names()
+
 key_models = get_key_models(
 user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
 )

@@ -7523,8 +7524,14 @@ async def model_info_v1(
 
 if len(all_models_str) > 0:
 model_names = all_models_str
-_relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
-all_models = copy.deepcopy(_relevant_models)
+llm_model_list = llm_router.get_model_list()
+if llm_model_list is not None:
+_relevant_models = [
+m for m in llm_model_list if m["model_name"] in model_names
+]
+all_models = copy.deepcopy(_relevant_models) # type: ignore
+else:
+all_models = []
 
 for model in all_models:
 # provided model_info in config.yaml

@@ -7590,12 +7597,12 @@ async def model_group_info(
 raise HTTPException(
 status_code=500, detail={"error": "LLM Router is not loaded in"}
 )
-all_models: List[dict] = []
 ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
-if llm_model_list is None:
+if llm_router is None:
 proxy_model_list = []
 else:
-proxy_model_list = [m["model_name"] for m in llm_model_list]
+proxy_model_list = llm_router.get_model_names()
+
 key_models = get_key_models(
 user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
 )

@@ -86,6 +86,7 @@ from litellm.types.router import (
 Deployment,
 DeploymentTypedDict,
 LiteLLM_Params,
+LiteLLMParamsTypedDict,
 ModelGroupInfo,
 ModelInfo,
 RetryPolicy,

@@ -4297,7 +4298,9 @@ class Router:
 return model
 return None
 
-def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+def _set_model_group_info(
+self, model_group: str, user_facing_model_group_name: str
+) -> Optional[ModelGroupInfo]:
 """
 For a given model group name, return the combined model info
 

@@ -4379,7 +4382,7 @@ class Router:
 
 if model_group_info is None:
 model_group_info = ModelGroupInfo(
-model_group=model_group, providers=[llm_provider], **model_info # type: ignore
+model_group=user_facing_model_group_name, providers=[llm_provider], **model_info # type: ignore
 )
 else:
 # if max_input_tokens > curr

@@ -4464,6 +4467,26 @@ class Router:
 
 return model_group_info
 
+def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+"""
+For a given model group name, return the combined model info
+
+Returns:
+- ModelGroupInfo if able to construct a model group
+- None if error constructing model group info
+"""
+## Check if model group alias
+if model_group in self.model_group_alias:
+return self._set_model_group_info(
+model_group=self.model_group_alias[model_group],
+user_facing_model_group_name=model_group,
+)
+
+## Check if actual model
+return self._set_model_group_info(
+model_group=model_group, user_facing_model_group_name=model_group
+)
+
 async def get_model_group_usage(
 self, model_group: str
 ) -> Tuple[Optional[int], Optional[int]]:

@@ -4534,19 +4557,35 @@ class Router:
 return ids
 
 def get_model_names(self) -> List[str]:
-return self.model_names
+"""
+Returns all possible model names for router.
+
+Includes model_group_alias models too.
+"""
+return self.model_names + list(self.model_group_alias.keys())
 
 def get_model_list(
 self, model_name: Optional[str] = None
 ) -> Optional[List[DeploymentTypedDict]]:
 if hasattr(self, "model_list"):
-if model_name is None:
-return self.model_list
-
 returned_models: List[DeploymentTypedDict] = []
+
+for model_alias, model_value in self.model_group_alias.items():
+model_alias_item = DeploymentTypedDict(
+model_name=model_alias,
+litellm_params=LiteLLMParamsTypedDict(model=model_value),
+)
+returned_models.append(model_alias_item)
+
+if model_name is None:
+returned_models += self.model_list
+
+return returned_models
+
 for model in self.model_list:
 if model["model_name"] == model_name:
 returned_models.append(model)
 
 return returned_models
 return None

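With get_model_names() and get_model_list() now including model_group_alias entries, aliases show up on /models, /model/info, and /model_group/info. A minimal Router-level sketch (the alias and model names are illustrative):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    # user-facing alias mapped onto the real model group
    model_group_alias={"my-alias": "gpt-3.5-turbo"},
)

print(router.get_model_names())  # expected: ["gpt-3.5-turbo", "my-alias"]
print(router.get_model_list())   # alias included as a DeploymentTypedDict entry
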
@@ -22,7 +22,7 @@ import litellm
 from litellm import create_batch, create_file
 
 
-@pytest.mark.parametrize("provider", ["openai", "azure"])
+@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
 def test_create_batch(provider):
 """
 1. Create File for Batch completion

@@ -96,7 +96,7 @@ def test_create_batch(provider):
 pass
 
 
-@pytest.mark.parametrize("provider", ["openai", "azure"])
+@pytest.mark.parametrize("provider", ["openai"]) # "azure"
 @pytest.mark.asyncio()
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_async_create_batch(provider):

@@ -2396,6 +2396,7 @@ async def test_router_weighted_pick(sync_mode):
 assert model_id_1_count > model_id_2_count
 
 
+@pytest.mark.skip(reason="Hit azure batch quota limits")
 @pytest.mark.parametrize("provider", ["azure"])
 @pytest.mark.asyncio
 async def test_router_batch_endpoints(provider):

@@ -12,6 +12,7 @@ import pytest
 sys.path.insert(
 0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import litellm
 from litellm import Router

@@ -1266,3 +1267,73 @@ async def test_using_default_working_fallback(sync_mode):
 )
 print("got response=", response)
 assert response is not None
+
+
+# asyncio.run(test_acompletion_gemini_stream())
+def mock_post_streaming(url, **kwargs):
+mock_response = MagicMock()
+mock_response.status_code = 529
+mock_response.headers = {"Content-Type": "application/json"}
+mock_response.return_value = {"detail": "Overloaded!"}
+
+return mock_response
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_anthropic_streaming_fallbacks(sync_mode):
+litellm.set_verbose = True
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+if sync_mode:
+client = HTTPHandler(concurrent_limit=1)
+else:
+client = AsyncHTTPHandler(concurrent_limit=1)
+
+router = Router(
+model_list=[
+{
+"model_name": "anthropic/claude-3-5-sonnet-20240620",
+"litellm_params": {
+"model": "anthropic/claude-3-5-sonnet-20240620",
+},
+},
+{
+"model_name": "gpt-3.5-turbo",
+"litellm_params": {
+"model": "gpt-3.5-turbo",
+"mock_response": "Hey, how's it going?",
+},
+},
+],
+fallbacks=[{"anthropic/claude-3-5-sonnet-20240620": ["gpt-3.5-turbo"]}],
+num_retries=0,
+)
+
+with patch.object(client, "post", side_effect=mock_post_streaming) as mock_client:
+chunks = []
+if sync_mode:
+response = router.completion(
+model="anthropic/claude-3-5-sonnet-20240620",
+messages=[{"role": "user", "content": "Hey, how's it going?"}],
+stream=True,
+client=client,
+)
+for chunk in response:
+print(chunk)
+chunks.append(chunk)
+else:
+response = await router.acompletion(
+model="anthropic/claude-3-5-sonnet-20240620",
+messages=[{"role": "user", "content": "Hey, how's it going?"}],
+stream=True,
+client=client,
+)
+async for chunk in response:
+print(chunk)
+chunks.append(chunk)
+print(f"RETURNED response: {response}")
+
+mock_client.assert_called_once()
+print(chunks)
+assert len(chunks) > 0

@@ -127,6 +127,7 @@ async def test_completion_sagemaker_messages_api(sync_mode):
 "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
 ],
 )
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_completion_sagemaker_stream(sync_mode, model):
 try:
 from litellm.tests.test_streaming import streaming_format_tests

@@ -3144,7 +3144,6 @@ async def test_azure_astreaming_and_function_calling():
 
 
 def test_completion_claude_3_function_call_with_streaming():
-litellm.set_verbose = True
 tools = [
 {
 "type": "function",

@@ -3827,6 +3826,65 @@ def test_unit_test_custom_stream_wrapper_function_call():
 assert len(new_model.choices[0].delta.tool_calls) > 0
 
 
+def test_unit_test_perplexity_citations_chunk():
+"""
+Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
+"""
+from litellm.types.llms.openai import ChatCompletionDeltaChunk
+
+litellm.set_verbose = False
+delta: ChatCompletionDeltaChunk = {
+"content": "B",
+"role": "assistant",
+}
+chunk = {
+"id": "xxx",
+"model": "llama-3.1-sonar-small-128k-online",
+"created": 1725494279,
+"usage": {"prompt_tokens": 15, "completion_tokens": 1, "total_tokens": 16},
+"citations": [
+"https://x.com/bizzabo?lang=ur",
+"https://apps.apple.com/my/app/bizzabo/id408705047",
+"https://www.bizzabo.com/blog/maximize-event-data-strategies-for-success",
+],
+"object": "chat.completion",
+"choices": [
+{
+"index": 0,
+"finish_reason": None,
+"message": {"role": "assistant", "content": "B"},
+"delta": delta,
+}
+],
+}
+chunk = litellm.ModelResponse(**chunk, stream=True)
+
+completion_stream = ModelResponseIterator(model_response=chunk)
+
+response = litellm.CustomStreamWrapper(
+completion_stream=completion_stream,
+model="gpt-3.5-turbo",
+custom_llm_provider="cached_response",
+logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+model="gpt-3.5-turbo",
+messages=[{"role": "user", "content": "Hey"}],
+stream=True,
+call_type="completion",
+start_time=time.time(),
+litellm_call_id="12345",
+function_id="1245",
+),
+)
+
+finish_reason: Optional[str] = None
+for response_chunk in response:
+if response_chunk.choices[0].delta.content is not None:
+print(
+f"response_chunk.choices[0].delta.content: {response_chunk.choices[0].delta.content}"
+)
+assert "citations" in response_chunk
+
+
 @pytest.mark.parametrize(
 "model",
 [

@@ -1297,3 +1297,7 @@ class CustomStreamingDecoder:
 self, iterator: Iterator[bytes]
 ) -> Iterator[Optional[Union[GenericStreamingChunk, StreamingChatCompletionChunk]]]:
 raise NotImplementedError
+
+
+class StandardPassThroughResponseObject(TypedDict):
+response: str

@@ -9845,6 +9845,9 @@ class CustomStreamWrapper:
 model_response.system_fingerprint = (
 original_chunk.system_fingerprint
 )
+model_response.citations = getattr(
+original_chunk, "citations", None
+)
 print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
 if self.sent_first_chunk is False:
 model_response.choices[0].delta["role"] = "assistant"

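With citations now copied onto each streamed chunk (see the unit test added above), Perplexity citation URLs can be read directly off the stream. A hedged usage sketch; it assumes a PERPLEXITYAI_API_KEY is set, and the model name follows the one used in the test:

import litellm

response = litellm.completion(
    model="perplexity/llama-3.1-sonar-small-128k-online",
    messages=[{"role": "user", "content": "Who founded Bizzabo?"}],
    stream=True,
)

for chunk in response:
    citations = getattr(chunk, "citations", None)
    if citations:
        print("citations:", citations)
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
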
@@ -10460,6 +10463,8 @@ class TextCompletionStreamWrapper:
 def mock_completion_streaming_obj(
 model_response, mock_response, model, n: Optional[int] = None
 ):
+if isinstance(mock_response, litellm.MockException):
+raise mock_response
 for i in range(0, len(mock_response), 3):
 completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
 if n is None:

@@ -10481,6 +10486,8 @@ def mock_completion_streaming_obj(
 async def async_mock_completion_streaming_obj(
 model_response, mock_response, model, n: Optional[int] = None
 ):
+if isinstance(mock_response, litellm.MockException):
+raise mock_response
 for i in range(0, len(mock_response), 3):
 completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
 if n is None:

@@ -3408,6 +3408,15 @@
 "litellm_provider": "bedrock",
 "mode": "chat"
 },
+"amazon.titan-text-premier-v1:0": {
+"max_tokens": 32000,
+"max_input_tokens": 42000,
+"max_output_tokens": 32000,
+"input_cost_per_token": 0.0000005,
+"output_cost_per_token": 0.0000015,
+"litellm_provider": "bedrock",
+"mode": "chat"
+},
 "amazon.titan-embed-text-v1": {
 "max_tokens": 8192,
 "max_input_tokens": 8192,