forked from phoenix/litellm-mirror

commit 73541b1f17
Merge branch 'msabramo/pydantic_replace_root_validator_with_model_validator' into msabramo/fix-pydantic-warnings

53 changed files with 1835 additions and 288 deletions
@@ -17,7 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
 - [Logging to Sentry](#logging-proxy-inputoutput---sentry)
 - [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
 - [Logging to Athina](#logging-proxy-inputoutput-athina)
-- [Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
+- [(BETA) Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)

 ## Custom Callback Class [Async]

 Use this when you want to run custom callbacks in `python`
@@ -1039,7 +1039,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
 }'
 ```

-## Moderation with Azure Content Safety
+## (BETA) Moderation with Azure Content Safety

 [Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
@@ -110,7 +110,7 @@ general_settings:
   admin_jwt_scope: "litellm-proxy-admin"
 ```

-## Advanced - Spend Tracking (User / Team / Org)
+## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)

 Set the field in the jwt token, which corresponds to a litellm user / team / org.
@@ -123,6 +123,7 @@ general_settings:
   team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
   user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
   org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
+  end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
 ```

 Expected JWT:
@@ -131,7 +132,7 @@ Expected JWT:
 {
   "client_id": "my-unique-team",
   "sub": "my-unique-user",
-  "org_id": "my-unique-org"
+  "org_id": "my-unique-org",
 }
 ```
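The config above maps arbitrary JWT claims onto LiteLLM's team / user / org / end-user identifiers. As a rough illustration of how those claims line up (a sketch, not LiteLLM's internal code: the claim names come from the example config and JWT above, while the use of PyJWT, the helper dict, and the `customer_id` value are assumptions):

```python
# Sketch: extract the claims that the *_jwt_field settings point at.
# Uses PyJWT (pip install pyjwt); the demo key is only for illustration --
# the proxy verifies real tokens against the identity provider.
import jwt

FIELD_MAP = {                      # mirrors the general_settings above
    "team_id": "client_id",
    "user_id": "sub",
    "org_id": "org_id",
    "end_user_id": "customer_id",  # field added in this diff
}

token = jwt.encode(
    {"client_id": "my-unique-team", "sub": "my-unique-user",
     "org_id": "my-unique-org", "customer_id": "my-unique-end-user"},
    key="demo-secret", algorithm="HS256",
)

claims = jwt.decode(token, key="demo-secret", algorithms=["HS256"])
identifiers = {field: claims.get(claim) for field, claim in FIELD_MAP.items()}
print(identifiers)  # e.g. {'team_id': 'my-unique-team', 'user_id': 'my-unique-user', ...}
```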
@@ -10,7 +10,6 @@ from litellm.caching import DualCache
 from typing import Literal, Union

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
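This hunk and many of the ones below make the same change: the module-level `dotenv.load_dotenv()` side effect is dropped from litellm's integration modules. If an application relied on that implicit loading, a common replacement (an assumption about usage, not something this diff adds) is to load the `.env` file once at the program's entry point:

```python
# Sketch: load environment variables once at startup instead of in every
# imported module (python-dotenv). File location is illustrative.
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory, if present

import litellm  # imported after the environment is populated
```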
@@ -19,8 +18,6 @@ import traceback
 import dotenv, os
 import requests

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -743,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
+from .llms.bedrock_httpx import AmazonCohereChatConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,
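The newly exported `AmazonCohereChatConfig` (defined in `litellm/llms/bedrock_httpx.py` later in this diff) translates OpenAI-style parameters into Bedrock Cohere Command-R parameters. A minimal sketch of that mapping, assuming this commit's export in `litellm/__init__.py`; the input values are made up:

```python
# Sketch: OpenAI-style params -> Bedrock Cohere Command-R params, using the
# config class added in this commit. Values are illustrative.
import litellm

config = litellm.AmazonCohereChatConfig()

mapped = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.9, "stop": "END"},
    optional_params={},
)
print(mapped)  # per the diff, top_p becomes "p" and stop becomes "stop_sequences"
```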
@@ -1,8 +1,6 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime
@@ -3,7 +3,6 @@
 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime
@@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

 from typing import Literal, Union

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -18,8 +16,6 @@ import traceback
 import dotenv, os
 import requests

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

 from typing import Literal, Union, Optional

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 import litellm

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -1,8 +1,6 @@
 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
+import os

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import copy
 import traceback
 from packaging.version import Version
@@ -474,7 +472,29 @@ class LangFuseLogger:
            }

        if supports_prompt:
-           generation_params["prompt"] = clean_metadata.pop("prompt", None)
+           user_prompt = clean_metadata.pop("prompt", None)
+           if user_prompt is None:
+               pass
+           elif isinstance(user_prompt, dict):
+               from langfuse.model import (
+                   TextPromptClient,
+                   ChatPromptClient,
+                   Prompt_Text,
+                   Prompt_Chat,
+               )
+
+               if user_prompt.get("type", "") == "chat":
+                   _prompt_chat = Prompt_Chat(**user_prompt)
+                   generation_params["prompt"] = ChatPromptClient(
+                       prompt=_prompt_chat
+                   )
+               elif user_prompt.get("type", "") == "text":
+                   _prompt_text = Prompt_Text(**user_prompt)
+                   generation_params["prompt"] = TextPromptClient(
+                       prompt=_prompt_text
+                   )
+           else:
+               generation_params["prompt"] = user_prompt

        if output is not None and isinstance(output, str) and level == "ERROR":
            generation_params["status_message"] = output
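With this change, a serialized Langfuse prompt passed under the `prompt` metadata key is rebuilt into a `ChatPromptClient` or `TextPromptClient` before the generation is logged. A rough usage sketch; only the `prompt` key and the `type` discriminator come from the diff, while the model name, the wiring of request metadata into `clean_metadata`, and the remaining dict fields are assumptions:

```python
# Sketch: attach a serialized Langfuse prompt to a completion call so the
# LangFuseLogger above can link it to the generation. Field values are illustrative.
import litellm

prompt_dict = {
    "type": "chat",  # "chat" -> ChatPromptClient, "text" -> TextPromptClient
    # ...remaining keys must match what langfuse.model.Prompt_Chat expects
}

response = litellm.completion(              # requires a configured provider API key
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"prompt": prompt_dict},       # popped as clean_metadata["prompt"] in the logger
)
```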
@@ -3,8 +3,6 @@
 import dotenv, os  # type: ignore
 import requests  # type: ignore
 from datetime import datetime

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import asyncio
 import types
@@ -2,13 +2,10 @@
 # On success + failure, log events to lunary.ai
 from datetime import datetime, timezone
 import traceback
-import dotenv
 import importlib

 import packaging

-dotenv.load_dotenv()


 # convert to {completion: xx, tokens: xx}
 def parse_usage(usage):
@@ -79,14 +76,16 @@ class LunaryLogger:
            version = importlib.metadata.version("lunary")
            # if version < 0.1.43 then raise ImportError
            if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
-               print(
+               print(  # noqa
                    "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
                )
                raise ImportError

            self.lunary_client = lunary
        except ImportError:
-           print("Lunary not installed. Please install it using 'pip install lunary'")
+           print(  # noqa
+               "Lunary not installed. Please install it using 'pip install lunary'"
+           )  # noqa
            raise ImportError

    def log_event(
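The version gate above is a small reusable pattern: read the installed distribution version with `importlib.metadata` and compare it via `packaging.version` before trusting an optional dependency. A standalone sketch of the same check; the package name and minimum version in the example call are arbitrary:

```python
# Sketch: enforce a minimum version for an optional dependency before using it.
import importlib.metadata
from packaging import version

def require_min_version(package: str, minimum: str) -> None:
    installed = importlib.metadata.version(package)  # raises PackageNotFoundError if absent
    if version.Version(installed) < version.Version(minimum):
        raise ImportError(
            f"{package} {installed} is too old; install >= {minimum} "
            f"via 'pip install --upgrade {package}'"
        )

require_min_version("packaging", "20.0")  # example call against a package that is present
```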
@@ -3,8 +3,6 @@
 import dotenv, os, json
 import litellm

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@@ -4,8 +4,6 @@

 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -5,8 +5,6 @@

 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 from pydantic import BaseModel

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -1,9 +1,7 @@
 #### What this does ####
 # On success + failure, log events to Supabase

-import dotenv, os
+import os

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -2,8 +2,6 @@
 # Class for sending Slack Alerts #
 import dotenv, os
 from litellm.proxy._types import UserAPIKeyAuth

-dotenv.load_dotenv()  # Loading env variables using dotenv
 from litellm._logging import verbose_logger, verbose_proxy_logger
 import litellm, threading
 from typing import List, Literal, Any, Union, Optional, Dict
@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm
@@ -21,11 +21,11 @@ try:
        # contains a (known) object attribute
        object: Literal["chat.completion", "edit", "text_completion"]

-       def __getitem__(self, key: K) -> V:
-           ...  # pragma: no cover
+       def __getitem__(self, key: K) -> V: ...  # noqa

-       def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
-           ...  # pragma: no cover
+       def get(  # noqa
+           self, key: K, default: Optional[V] = None
+       ) -> Optional[V]: ...  # pragma: no cover

    class OpenAIRequestResponseResolver:
        def __call__(
@@ -173,12 +173,11 @@ except:

 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
+import os
 import requests
 import requests
 from datetime import datetime

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests, copy  # type: ignore
 import time
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Union
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    def process_streaming_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> CustomStreamWrapper:
+        """
+        Return stream object for tool-calling + streaming
+        """
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        _message = litellm.Message(
+            tool_calls=tool_calls,
+            content=text_content or None,
+        )
+        model_response.choices[0].message = _message  # type: ignore
+        model_response._hidden_params["original_response"] = completion_response[
+            "content"
+        ]  # allow user to access raw anthropic tool calling response
+
+        model_response.choices[0].finish_reason = map_finish_reason(
+            completion_response["stop_reason"]
+        )
+
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+        else:
+            raise AnthropicError(
+                status_code=422,
+                message="Unprocessable response object - {}".format(response.text),
+            )
+
     def process_response(
         self,
-        model,
-        response,
-        model_response,
-        _is_function_call,
-        stream,
-        logging_obj,
-        api_key,
-        data,
-        messages,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
         print_verbose,
-    ):
+        encoding,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
             input=messages,
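The new `process_streaming_response` turns an already-complete tool-calling response into a synthetic stream, so callers that asked for `stream=True` still receive chunk objects. The core trick, stripped of the litellm-specific types (a generic illustration, not the code above; the `Chunk` dataclass and the tool-call payload are invented for the example):

```python
# Sketch: wrap a fully materialized response in a one-chunk iterator so that
# streaming consumers can iterate over it like a live stream.
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class Chunk:
    content: Optional[str]
    tool_calls: List[dict]
    finish_reason: Optional[str]


def as_stream(content: Optional[str], tool_calls: List[dict], finish_reason: str) -> Iterator[Chunk]:
    # Yield everything in a single delta, then stop -- the "cached_response" pattern.
    yield Chunk(content=content, tool_calls=tool_calls, finish_reason=finish_reason)


for chunk in as_stream(
    None,
    [{"id": "call_1", "type": "function",
      "function": {"name": "get_weather", "arguments": "{}"}}],
    finish_reason="tool_calls",
):
    print(chunk)
```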
@@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
             completion_response["stop_reason"]
         )

-        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
-                )
-
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)  # type: ignore
         return model_response

     async def acompletion_stream_function(
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
_is_function_call,
|
_is_function_call,
|
||||||
data=None,
|
data: dict,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
_is_function_call,
|
_is_function_call,
|
||||||
data=None,
|
data: dict,
|
||||||
optional_params=None,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
headers={},
|
headers={},
|
||||||
):
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
self.async_handler = AsyncHTTPHandler(
|
self.async_handler = AsyncHTTPHandler(
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
)
|
)
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
api_base, headers=headers, data=json.dumps(data)
|
api_base, headers=headers, data=json.dumps(data)
|
||||||
)
|
)
|
||||||
|
if stream and _is_function_call:
|
||||||
|
return self.process_streaming_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
return self.process_response(
|
return self.process_response(
|
||||||
model=model,
|
model=model,
|
||||||
response=response,
|
response=response,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
_is_function_call=_is_function_call,
|
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
|
@@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
             raise AnthropicError(
                 status_code=response.status_code, message=response.text
             )

+        if stream and _is_function_call:
+            return self.process_streaming_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
         return self.process_response(
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )

     def embedding(self):
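Both the sync and async Anthropic paths now detour through `process_streaming_response` when the caller combines tool calling with streaming. From the caller's side that combination is an ordinary streaming completion request; a hedged sketch (the model name, tool schema, and environment setup are illustrative, not taken from this diff):

```python
# Sketch: a request that exercises the new "stream + tool calling" branch.
# Requires ANTHROPIC_API_KEY in the environment; the tool schema is made up.
import litellm

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

stream = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    stream=True,  # with tools, the full response is replayed through the cached-response stream
)

for chunk in stream:
    print(chunk.choices[0].delta)
```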
@@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def process_response(
+    def _process_response(
         self, model_response: ModelResponse, response, encoding, prompt: str, model: str
     ):
         ## RESPONSE OBJECT
@@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
             additional_args={"complete_input_dict": data},
         )

-        response = self.process_response(
+        response = self._process_response(
             model_response=model_response,
             response=response,
             encoding=encoding,
@@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
         )
         print_verbose(f"raw model_response: {response.text}")

-        response = self.process_response(
+        response = self._process_response(
             model_response=model_response,
             response=response,
             encoding=encoding,
@@ -1,12 +1,32 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
-import httpx
-from typing import Optional
+import httpx, requests
+from typing import Optional, Union
+from litellm.utils import Logging


 class BaseLLM:
     _client_session: Optional[httpx.Client] = None

+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> litellm.utils.ModelResponse:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
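`BaseLLM.process_response` is now a default hook that simply returns the `ModelResponse` it was given; provider classes (Anthropic above, Bedrock below) override it with the same signature. A minimal sketch of that pattern outside litellm, with simplified stand-in types and invented names:

```python
# Sketch: a base class exposes a no-op response hook; providers override it
# with the same signature. Dicts stand in for the litellm response types.
from typing import Any


class BaseProvider:
    def process_response(self, raw: dict, model_response: dict, **kwargs: Any) -> dict:
        """Default: hand the pre-built response back unchanged."""
        return model_response


class ExampleProvider(BaseProvider):
    def process_response(self, raw: dict, model_response: dict, **kwargs: Any) -> dict:
        # Provider-specific parsing of the raw HTTP payload into the shared shape.
        model_response["content"] = raw.get("text", "")
        return model_response


print(ExampleProvider().process_response({"text": "hi"}, {}))  # {'content': 'hi'}
```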
litellm/llms/bedrock_httpx.py (new file, 733 lines)

@@ -0,0 +1,733 @@
|
# What is this?
|
||||||
|
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||||
|
## V0 - just covers cohere command-r support
|
||||||
|
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
TypedDict,
|
||||||
|
Tuple,
|
||||||
|
Iterator,
|
||||||
|
AsyncIterator,
|
||||||
|
)
|
||||||
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
map_finish_reason,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
Choices,
|
||||||
|
get_secret,
|
||||||
|
Logging,
|
||||||
|
)
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||||
|
from litellm.types.llms.bedrock import *
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonCohereChatConfig:
|
||||||
|
"""
|
||||||
|
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
documents: Optional[List[Document]] = None
|
||||||
|
search_queries_only: Optional[bool] = None
|
||||||
|
preamble: Optional[str] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
p: Optional[float] = None
|
||||||
|
k: Optional[float] = None
|
||||||
|
prompt_truncation: Optional[str] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
return_prompt: Optional[bool] = None
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
raw_prompting: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
search_queries_only: Optional[bool] = None,
|
||||||
|
preamble: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
p: Optional[float] = None,
|
||||||
|
k: Optional[float] = None,
|
||||||
|
prompt_truncation: Optional[str] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_prompt: Optional[bool] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
raw_prompting: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"seed",
|
||||||
|
"stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = [value]
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["p"] = value
|
||||||
|
if param == "frequency_penalty":
|
||||||
|
optional_params["frequency_penalty"] = value
|
||||||
|
if param == "presence_penalty":
|
||||||
|
optional_params["presence_penalty"] = value
|
||||||
|
if "seed":
|
||||||
|
optional_params["seed"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockLLM(BaseLLM):
|
||||||
|
"""
|
||||||
|
Example call
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Accept: application/json' \
|
||||||
|
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
|
||||||
|
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
|
||||||
|
--data-raw '{
|
||||||
|
"prompt": "Hi",
|
||||||
|
"temperature": 0,
|
||||||
|
"p": 0.9,
|
||||||
|
"max_tokens": 4096
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def convert_messages_to_prompt(
|
||||||
|
self, model, messages, provider, custom_prompt_dict
|
||||||
|
) -> Tuple[str, Optional[list]]:
|
||||||
|
# handle anthropic prompts and amazon titan prompts
|
||||||
|
prompt = ""
|
||||||
|
chat_history: Optional[list] = None
|
||||||
|
if provider == "anthropic" or provider == "amazon":
|
||||||
|
if model in custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details["roles"],
|
||||||
|
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||||
|
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "mistral":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "meta":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "cohere":
|
||||||
|
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||||
|
else:
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
if "role" in message:
|
||||||
|
if message["role"] == "user":
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
return prompt, chat_history # type: ignore
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_region_name: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_profile_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return a boto3.Credentials object
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
params_to_check: List[Optional[str]] = [
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith("os.environ/"):
|
||||||
|
_v = get_secret(param)
|
||||||
|
if _v is not None and isinstance(_v, str):
|
||||||
|
params_to_check[i] = _v
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
(
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
) = params_to_check
|
||||||
|
|
||||||
|
### CHECK STS ###
|
||||||
|
if aws_role_name is not None and aws_session_name is not None:
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return sts_response["Credentials"]
|
||||||
|
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||||
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
|
return client.get_credentials()
|
||||||
|
else:
|
||||||
|
session = boto3.Session(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.get_credentials()
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
prompt_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-input-token-count",
|
||||||
|
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
completion_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-output-token-count",
|
||||||
|
len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response.choices[0].message.content, # type: ignore
|
||||||
|
disallowed_special=(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
## SETUP ##
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url = ""
|
||||||
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||||
|
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
provider = model.split(".")[0]
|
||||||
|
prompt, chat_history = self.convert_messages_to_prompt(
|
||||||
|
model, messages, provider, custom_prompt_dict
|
||||||
|
)
|
||||||
|
inference_params = copy.deepcopy(optional_params)
|
||||||
|
|
||||||
|
if provider == "cohere":
|
||||||
|
if model.startswith("cohere.command-r"):
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereChatConfig().get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
_data = {"message": prompt, **inference_params}
|
||||||
|
if chat_history is not None:
|
||||||
|
_data["chat_history"] = chat_history
|
||||||
|
data = json.dumps(_data)
|
||||||
|
else:
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
if stream == True:
|
||||||
|
inference_params["stream"] = (
|
||||||
|
True # cohere requires stream = True in inference params
|
||||||
|
)
|
||||||
|
data = json.dumps({"prompt": prompt, **inference_params})
|
||||||
|
else:
|
||||||
|
raise Exception("UNSUPPORTED PROVIDER")
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": prepped.url,
|
||||||
|
"headers": prepped.headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=False,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
response = self.client.post(
|
||||||
|
url=prepped.url,
|
||||||
|
headers=prepped.headers, # type: ignore
|
||||||
|
data=data,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(
|
||||||
|
status_code=response.status_code, message=response.text
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = AWSEventStreamDecoder()
|
||||||
|
|
||||||
|
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
        response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ModelResponse:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            self.client = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.client = client  # type: ignore

        response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> CustomStreamWrapper:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            self.client = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.client = client  # type: ignore

        response = await self.client.post(api_base, headers=headers, data=data, stream=True)  # type: ignore

        if response.status_code != 200:
            raise BedrockError(status_code=response.status_code, message=response.text)

        decoder = AWSEventStreamDecoder()

        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    def embedding(self, *args, **kwargs):
        return super().embedding(*args, **kwargs)


def get_response_stream_shape():
    from botocore.model import ServiceModel
    from botocore.loaders import Loader

    loader = Loader()
    bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
    bedrock_service_model = ServiceModel(bedrock_service_dict)
    return bedrock_service_model.shape_for("ResponseStream")


class AWSEventStreamDecoder:
    def __init__(self) -> None:
        from botocore.parsers import EventStreamJSONParser

        self.parser = EventStreamJSONParser()

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message:
                    # sse_event = ServerSentEvent(data=message, event="completion")
                    _data = json.loads(message)
                    streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
                        text=_data.get("text", ""),
                        is_finished=_data.get("is_finished", False),
                        finish_reason=_data.get("finish_reason", ""),
                    )
                    yield streaming_chunk

    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
    ) -> AsyncIterator[GenericStreamingChunk]:
        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        async for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message:
                    _data = json.loads(message)
                    streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
                        text=_data.get("text", ""),
                        is_finished=_data.get("is_finished", False),
                        finish_reason=_data.get("finish_reason", ""),
                    )
                    yield streaming_chunk

    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
        if response_dict["status_code"] != 200:
            raise ValueError(f"Bad response code, expected 200: {response_dict}")

        chunk = parsed_response.get("chunk")
        if not chunk:
            return None

        return chunk.get("bytes").decode()  # type: ignore[no-any-return]
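A rough consumer sketch for the decoder above, assuming nothing beyond what the diff shows: `GenericStreamingChunk` is stood in for by a local TypedDict, and the fake async generator only mimics the shape of what `AWSEventStreamDecoder.aiter_bytes` yields (the real litellm types and event payloads may differ).

```python
# Minimal sketch: how a caller might drain decoded Bedrock stream chunks.
import asyncio
from typing import AsyncIterator, TypedDict


class GenericStreamingChunk(TypedDict):  # stand-in for the litellm type
    text: str
    is_finished: bool
    finish_reason: str


async def fake_decoded_stream() -> AsyncIterator[GenericStreamingChunk]:
    # stand-in for AWSEventStreamDecoder.aiter_bytes(...) output
    for piece in ["Hello", ", world"]:
        yield GenericStreamingChunk(text=piece, is_finished=False, finish_reason="")
    yield GenericStreamingChunk(text="", is_finished=True, finish_reason="stop")


async def main() -> None:
    buffer = ""
    async for chunk in fake_decoded_stream():
        buffer += chunk["text"]
        if chunk["is_finished"]:
            print(buffer, "| finish_reason:", chunk["finish_reason"])


asyncio.run(main())
```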
@@ -58,16 +58,25 @@ class AsyncHTTPHandler:

 class HTTPHandler:
     def __init__(
-        self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+        self,
+        timeout: Optional[httpx.Timeout] = None,
+        concurrent_limit=1000,
+        client: Optional[httpx.Client] = None,
     ):
-        # Create a client with a connection pool
-        self.client = httpx.Client(
-            timeout=timeout,
-            limits=httpx.Limits(
-                max_connections=concurrent_limit,
-                max_keepalive_connections=concurrent_limit,
-            ),
-        )
+        if timeout is None:
+            timeout = _DEFAULT_TIMEOUT
+
+        if client is None:
+            # Create a client with a connection pool
+            self.client = httpx.Client(
+                timeout=timeout,
+                limits=httpx.Limits(
+                    max_connections=concurrent_limit,
+                    max_keepalive_connections=concurrent_limit,
+                ),
+            )
+        else:
+            self.client = client

     def close(self):
         # Close the client when you're done with it

@@ -82,11 +91,15 @@ class HTTPHandler:
     def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
+        stream: bool = False,
     ):
-        response = self.client.post(url, data=data, params=params, headers=headers)
+        req = self.client.build_request(
+            "POST", url, data=data, params=params, headers=headers  # type: ignore
+        )
+        response = self.client.send(req, stream=stream)
         return response

     def __del__(self) -> None:
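A short usage sketch for the updated handler: a preconfigured `httpx.Client` is injected and a streamed response is requested. The import path and the URL are assumptions (they are not shown in this diff).

```python
import httpx
from litellm.llms.custom_httpx.http_handler import HTTPHandler  # assumed path

shared_client = httpx.Client(timeout=httpx.Timeout(30.0))
handler = HTTPHandler(client=shared_client)

# stream=True defers reading the body until the caller iterates over it
response = handler.post(
    "https://example.com/v1/generate",  # placeholder URL
    data={"prompt": "hi"},
    headers={"accept": "application/json"},
    stream=True,
)
for chunk in response.iter_bytes(chunk_size=1024):
    print(len(chunk))
response.close()
```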
@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj: litellm.utils.Logging,
         optional_params: dict,
         api_key: str,
-        data: dict,
+        data: Union[dict, str],
         messages: list,
         print_verbose,
         encoding,

@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
         try:
             completion_response = response.json()
         except:
-            raise PredibaseError(
-                message=response.text, status_code=response.status_code
-            )
+            raise PredibaseError(message=response.text, status_code=422)
         if "error" in completion_response:
             raise PredibaseError(
                 message=str(completion_response["error"]),

@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion is True:
+        if acompletion == True:
             ### ASYNC STREAMING
             if stream == True:
                 return self.async_streaming(
litellm/main.py (122 changes)

@@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
+
 from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,

@@ -76,6 +77,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
+from .llms.bedrock_httpx import BedrockLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,

@@ -105,7 +107,6 @@ from litellm.utils import (
 )

 ####### ENVIRONMENT VARIABLES ###################
-dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()

@@ -115,6 +116,7 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
+bedrock_chat_completion = BedrockLLM()
 ####### COMPLETION ENDPOINTS ################


@@ -257,7 +259,7 @@ async def acompletion(
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    custom_llm_provider = None
+    custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
     completion_kwargs = {
         "model": model,

@@ -289,9 +291,10 @@ async def acompletion(
         "model_list": model_list,
         "acompletion": True,  # assuming this is a required parameter
     }
-    _, custom_llm_provider, _, _ = get_llm_provider(
-        model=model, api_base=completion_kwargs.get("base_url", None)
-    )
+    if custom_llm_provider is None:
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=completion_kwargs.get("base_url", None)
+        )
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)

@@ -300,9 +303,6 @@ async def acompletion(
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)

-        _, custom_llm_provider, _, _ = get_llm_provider(
-            model=model, api_base=kwargs.get("api_base", None)
-        )
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"

@@ -324,6 +324,7 @@ async def acompletion(
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
+            or (custom_llm_provider == "bedrock" and "cohere" in model)
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)

@@ -1213,11 +1214,12 @@ def completion(
         )

         response = model_response
-    elif ("clarifai" in model
-          or custom_llm_provider == "clarifai"
-          or model in litellm.clarifai_models
-         ):
+    elif (
+        "clarifai" in model
+        or custom_llm_provider == "clarifai"
+        or model in litellm.clarifai_models
+    ):
         clarifai_key = None
         clarifai_key = (
             api_key
             or litellm.clarifai_key

@@ -1225,14 +1227,14 @@ def completion(
             or get_secret("CLARIFAI_API_KEY")
             or get_secret("CLARIFAI_API_TOKEN")
         )

         api_base = (
             api_base
             or litellm.api_base
             or get_secret("CLARIFAI_API_BASE")
             or "https://api.clarifai.com/v2"
         )

         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         model_response = clarifai.completion(
             model=model,

@@ -1249,7 +1251,7 @@ def completion(
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
         )

         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
             ## LOGGING

@@ -1258,7 +1260,7 @@ def completion(
                 api_key=api_key,
                 original_response=model_response,
             )

         if optional_params.get("stream", False) or acompletion == True:
             ## LOGGING
             logging.post_call(

@@ -1976,41 +1978,59 @@ def completion(
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = bedrock.completion(
-            model=model,
-            messages=messages,
-            custom_prompt_dict=litellm.custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,
-            logging_obj=logging,
-            extra_headers=extra_headers,
-            timeout=timeout,
-        )

-        if (
-            "stream" in optional_params
-            and optional_params["stream"] == True
-            and not isinstance(response, CustomStreamWrapper)
-        ):
-            # don't try to access stream object,
-            if "ai21" in model:
-                response = CustomStreamWrapper(
-                    response,
-                    model,
-                    custom_llm_provider="bedrock",
-                    logging_obj=logging,
-                )
-            else:
-                response = CustomStreamWrapper(
-                    iter(response),
-                    model,
-                    custom_llm_provider="bedrock",
-                    logging_obj=logging,
-                )
+        if "cohere" in model:
+            response = bedrock_chat_completion.completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                extra_headers=extra_headers,
+                timeout=timeout,
+                acompletion=acompletion,
+            )
+        else:
+            response = bedrock.completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                extra_headers=extra_headers,
+                timeout=timeout,
+            )

+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
+                # don't try to access stream object,
+                if "ai21" in model:
+                    response = CustomStreamWrapper(
+                        response,
+                        model,
+                        custom_llm_provider="bedrock",
+                        logging_obj=logging,
+                    )
+                else:
+                    response = CustomStreamWrapper(
+                        iter(response),
+                        model,
+                        custom_llm_provider="bedrock",
+                        logging_obj=logging,
+                    )

         if optional_params.get("stream", False):
             ## LOGGING
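A hedged sketch of how the new branch above is exercised: Bedrock Cohere command models are routed through the httpx-based `BedrockLLM` client, while other Bedrock models keep the existing boto3 path. This assumes AWS credentials are already available in the environment, and the stream chunks follow the OpenAI-style delta objects litellm normally returns.

```python
import litellm

response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Say hi in one sentence."}],
    max_tokens=50,
)
print(response.choices[0].message.content)

# Streaming goes through the AWSEventStreamDecoder path shown earlier in this diff.
stream = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Count to three."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```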
@@ -2644,6 +2644,24 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "cohere.command-r-plus-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000030,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "bedrock",
+        "mode": "chat"
+    },
+    "cohere.command-r-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000015,
+        "litellm_provider": "bedrock",
+        "mode": "chat"
+    },
     "cohere.embed-english-v3": {
         "max_tokens": 512,
         "max_input_tokens": 512,
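A back-of-the-envelope check of the new pricing entries, assuming the usual per-token cost formula (prompt tokens times input cost plus completion tokens times output cost); the token counts below are illustrative only.

```python
# cohere.command-r-plus-v1:0 from the map above
input_cost_per_token = 0.0000030
output_cost_per_token = 0.000015

prompt_tokens, completion_tokens = 1_000, 500
cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${cost:.4f}")  # $0.0105
```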
@@ -6,6 +6,15 @@ import uuid
 import json
 from litellm.types.router import UpdateRouterConfig

+try:
+    from pydantic import model_validator  # pydantic v2
+except ImportError:
+    from pydantic import root_validator  # pydantic v1
+
+    def model_validator(mode):
+        pre = mode == "before"
+        return root_validator(pre=pre)
+

 def hash_token(token: str):
     import hashlib

@@ -183,8 +192,14 @@ class LiteLLM_JWTAuth(LiteLLMBase):

     admin_jwt_scope: str = "litellm_proxy_admin"
     admin_allowed_routes: List[
-        Literal["openai_routes", "info_routes", "management_routes"]
-    ] = ["management_routes"]
+        Literal[
+            "openai_routes",
+            "info_routes",
+            "management_routes",
+            "spend_tracking_routes",
+            "global_spend_tracking_routes",
+        ]
+    ] = ["management_routes", "spend_tracking_routes", "global_spend_tracking_routes"]
     team_jwt_scope: str = "litellm_team"
     team_id_jwt_field: str = "client_id"
     team_allowed_routes: List[

@@ -217,7 +232,7 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
     llm_api_system_prompt: Optional[str] = None
     llm_api_fail_call_string: Optional[str] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_llm_api_params(cls, values):
         llm_api_check = values.get("llm_api_check")
         if llm_api_check is True:

@@ -309,7 +324,7 @@ class ModelInfo(LiteLLMBase):
         protected_namespaces = (),
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         if values.get("id") is None:
             values.update({"id": str(uuid.uuid4())})

@@ -339,7 +354,7 @@ class ModelParams(LiteLLMBase):
         protected_namespaces = (),
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         if values.get("model_info") is None:
             values.update({"model_info": ModelInfo()})

@@ -387,7 +402,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
     user_id: Optional[str] = None
     token_id: Optional[str] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         if values.get("token") is not None:
             values.update({"key": values.get("token")})

@@ -457,7 +472,7 @@ class UpdateUserRequest(GenerateRequestBase):
     user_role: Optional[str] = None
     max_budget: Optional[float] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_user_info(cls, values):
         if values.get("user_id") is None and values.get("user_email") is None:
             raise ValueError("Either user id or user email must be provided")

@@ -477,7 +492,7 @@ class NewEndUserRequest(LiteLLMBase):
         None  # if no equivalent model in allowed region - default all requests to this model
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_user_info(cls, values):
         if values.get("max_budget") is not None and values.get("budget_id") is not None:
             raise ValueError("Set either 'max_budget' or 'budget_id', not both.")

@@ -490,7 +505,7 @@ class Member(LiteLLMBase):
     user_id: Optional[str] = None
     user_email: Optional[str] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_user_info(cls, values):
         if values.get("user_id") is None and values.get("user_email") is None:
             raise ValueError("Either user id or user email must be provided")

@@ -536,7 +551,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
     user_id: Optional[str] = None
     user_email: Optional[str] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_user_info(cls, values):
         if values.get("user_id") is None and values.get("user_email") is None:
             raise ValueError("Either user id or user email must be provided")

@@ -574,7 +589,7 @@ class LiteLLM_TeamTable(TeamBase):
         protected_namespaces = (),
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         dict_fields = [
             "metadata",

@@ -873,7 +888,7 @@ class UserAPIKeyAuth(
     user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
     allowed_model_region: Optional[Literal["eu"]] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def check_api_key(cls, values):
         if values.get("api_key") is not None:
             values.update({"token": hash_token(values.get("api_key"))})

@@ -900,7 +915,7 @@ class LiteLLM_UserTable(LiteLLMBase):
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         if values.get("spend") is None:
             values.update({"spend": 0.0})

@@ -922,7 +937,7 @@ class LiteLLM_EndUserTable(LiteLLMBase):
     default_model: Optional[str] = None
     litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def set_model_info(cls, values):
         if values.get("spend") is None:
             values.update({"spend": 0.0})
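A small sketch of what the compatibility shim above buys: the same "before"-mode validator decorates a model under pydantic v1 (via `root_validator`) and pydantic v2 (native `model_validator`). The model below is illustrative and not taken from litellm.

```python
from typing import Optional

try:
    from pydantic import model_validator  # pydantic v2
except ImportError:
    from pydantic import root_validator  # pydantic v1

    def model_validator(mode):
        pre = mode == "before"
        return root_validator(pre=pre)

from pydantic import BaseModel


class ExampleKeyRequest(BaseModel):  # hypothetical example model
    token: Optional[str] = None
    key: Optional[str] = None

    @model_validator(mode="before")
    def copy_token_to_key(cls, values):
        # runs on the raw input dict before field validation on both versions
        if values.get("token") is not None and values.get("key") is None:
            values["key"] = values["token"]
        return values


print(ExampleKeyRequest(token="sk-123").key)  # sk-123
```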
@@ -1,10 +1,7 @@
 from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
 from fastapi import Request
-from dotenv import load_dotenv
 import os

-load_dotenv()

 async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
     try:

@@ -425,7 +425,7 @@ async def user_api_key_auth(
                 litellm_proxy_roles=jwt_handler.litellm_jwtauth,
             )
             if is_allowed == False:
-                allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
+                allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes  # type: ignore
                 actual_routes = get_actual_routes(allowed_routes=allowed_routes)
                 raise Exception(
                     f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
@@ -2263,10 +2263,18 @@ class ProxyConfig:
                     _PROXY_AzureContentSafety,
                 )

-                azure_content_safety_params = litellm_settings["azure_content_safety_params"]
+                azure_content_safety_params = litellm_settings[
+                    "azure_content_safety_params"
+                ]
                 for k, v in azure_content_safety_params.items():
-                    if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
-                        azure_content_safety_params[k] = litellm.get_secret(v)
+                    if (
+                        v is not None
+                        and isinstance(v, str)
+                        and v.startswith("os.environ/")
+                    ):
+                        azure_content_safety_params[k] = (
+                            litellm.get_secret(v)
+                        )

                 azure_content_safety_obj = _PROXY_AzureContentSafety(
                     **azure_content_safety_params,
@@ -1507,22 +1507,30 @@ class Router:
                 return response
             except Exception as e:
                 original_exception = e
-                ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
-                if (
-                    isinstance(original_exception, litellm.ContextWindowExceededError)
-                    and context_window_fallbacks is not None
-                ) or (
-                    isinstance(original_exception, openai.RateLimitError)
-                    and fallbacks is not None
-                ):
-                    raise original_exception
-                ### RETRY
-                _timeout = self._router_should_retry(
+                """
+                Retry Logic
+
+                """
+                _healthy_deployments = await self._async_get_healthy_deployments(
+                    model=kwargs.get("model"),
+                )
+
+                # raises an exception if this error should not be retries
+                self.should_retry_this_error(
+                    error=e,
+                    healthy_deployments=_healthy_deployments,
+                    context_window_fallbacks=context_window_fallbacks,
+                )
+
+                # decides how long to sleep before retry
+                _timeout = self._time_to_sleep_before_retry(
                     e=original_exception,
                     remaining_retries=num_retries,
                     num_retries=num_retries,
+                    healthy_deployments=_healthy_deployments,
                 )
+
+                # sleeps for the length of the timeout
                 await asyncio.sleep(_timeout)
+
                 if (

@@ -1556,10 +1564,14 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _timeout = self._router_should_retry(
+                    _healthy_deployments = await self._async_get_healthy_deployments(
+                        model=kwargs.get("model"),
+                    )
+                    _timeout = self._time_to_sleep_before_retry(
                         e=original_exception,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
+                        healthy_deployments=_healthy_deployments,
                     )
                     await asyncio.sleep(_timeout)
             try:

@@ -1568,6 +1580,40 @@ class Router:
                 pass
             raise original_exception

+    def should_retry_this_error(
+        self,
+        error: Exception,
+        healthy_deployments: Optional[List] = None,
+        context_window_fallbacks: Optional[List] = None,
+    ):
+        """
+        1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
+
+        2. raise an exception for RateLimitError if
+        - there are no fallbacks
+        - there are no healthy deployments in the same model group
+        """
+
+        _num_healthy_deployments = 0
+        if healthy_deployments is not None and isinstance(healthy_deployments, list):
+            _num_healthy_deployments = len(healthy_deployments)
+
+        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
+        if (
+            isinstance(error, litellm.ContextWindowExceededError)
+            and context_window_fallbacks is None
+        ):
+            raise error
+
+        # Error we should only retry if there are other deployments
+        if isinstance(error, openai.RateLimitError) or isinstance(
+            error, openai.AuthenticationError
+        ):
+            if _num_healthy_deployments <= 0:
+                raise error
+
+        return True
+
     def function_with_fallbacks(self, *args, **kwargs):
         """
         Try calling the function_with_retries

@@ -1656,12 +1702,27 @@ class Router:
                 raise e
         raise original_exception

-    def _router_should_retry(
-        self, e: Exception, remaining_retries: int, num_retries: int
+    def _time_to_sleep_before_retry(
+        self,
+        e: Exception,
+        remaining_retries: int,
+        num_retries: int,
+        healthy_deployments: Optional[List] = None,
     ) -> Union[int, float]:
         """
         Calculate back-off, then retry
+
+        It should instantly retry only when:
+            1. there are healthy deployments in the same model group
+            2. there are fallbacks for the completion call
         """
+        if (
+            healthy_deployments is not None
+            and isinstance(healthy_deployments, list)
+            and len(healthy_deployments) > 0
+        ):
+            return 0
+
         if hasattr(e, "response") and hasattr(e.response, "headers"):
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,

@@ -1698,23 +1759,29 @@ class Router:
         except Exception as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
-            if (
-                isinstance(original_exception, litellm.ContextWindowExceededError)
-                and context_window_fallbacks is not None
-            ) or (
-                isinstance(original_exception, openai.RateLimitError)
-                and fallbacks is not None
-            ):
-                raise original_exception
-            ## LOGGING
-            if num_retries > 0:
-                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
-            ### RETRY
-            _timeout = self._router_should_retry(
+            _healthy_deployments = self._get_healthy_deployments(
+                model=kwargs.get("model"),
+            )
+
+            # raises an exception if this error should not be retries
+            self.should_retry_this_error(
+                error=e,
+                healthy_deployments=_healthy_deployments,
+                context_window_fallbacks=context_window_fallbacks,
+            )
+
+            # decides how long to sleep before retry
+            _timeout = self._time_to_sleep_before_retry(
                 e=original_exception,
                 remaining_retries=num_retries,
                 num_retries=num_retries,
+                healthy_deployments=_healthy_deployments,
             )
+
+            ## LOGGING
+            if num_retries > 0:
+                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
+
             time.sleep(_timeout)
             for current_attempt in range(num_retries):
                 verbose_router_logger.debug(

@@ -1728,11 +1795,15 @@ class Router:
                 except Exception as e:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
+                    _healthy_deployments = self._get_healthy_deployments(
+                        model=kwargs.get("model"),
+                    )
                     remaining_retries = num_retries - current_attempt
-                    _timeout = self._router_should_retry(
+                    _timeout = self._time_to_sleep_before_retry(
                         e=e,
                         remaining_retries=remaining_retries,
                         num_retries=num_retries,
+                        healthy_deployments=_healthy_deployments,
                     )
                     time.sleep(_timeout)
         raise original_exception

@@ -1935,6 +2006,47 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models

+    def _get_healthy_deployments(self, model: str):
+        _all_deployments: list = []
+        try:
+            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
+                model=model,
+            )
+            if type(_all_deployments) == dict:
+                return []
+        except:
+            pass
+
+        unhealthy_deployments = self._get_cooldown_deployments()
+        healthy_deployments: list = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+
+        return healthy_deployments
+
+    async def _async_get_healthy_deployments(self, model: str):
+        _all_deployments: list = []
+        try:
+            _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
+                model=model,
+            )
+            if type(_all_deployments) == dict:
+                return []
+        except:
+            pass
+
+        unhealthy_deployments = await self._async_get_cooldown_deployments()
+        healthy_deployments: list = []
+        for deployment in _all_deployments:
+            if deployment["model_info"]["id"] in unhealthy_deployments:
+                continue
+            else:
+                healthy_deployments.append(deployment)
+        return healthy_deployments
+
     def routing_strategy_pre_call_checks(self, deployment: dict):
         """
         Mimics 'async_routing_strategy_pre_call_checks'
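A simplified re-statement of the retry decision introduced above, not the Router code itself: an error is only retried when retrying can plausibly help, and the back-off collapses to zero when other healthy deployments exist. The exception classes and the exponential back-off here are stand-ins; the Router defers to litellm's own exception types and `_calculate_retry_after`.

```python
from typing import List, Optional


class ContextWindowExceededError(Exception): ...
class RateLimitError(Exception): ...


def should_retry(error: Exception,
                 healthy_deployments: Optional[List] = None,
                 context_window_fallbacks: Optional[List] = None) -> bool:
    n_healthy = len(healthy_deployments or [])
    # context-window errors are only retried when a fallback model group exists
    if isinstance(error, ContextWindowExceededError) and context_window_fallbacks is None:
        raise error
    # rate-limit style errors are only retried if another deployment can take the call
    if isinstance(error, RateLimitError) and n_healthy <= 0:
        raise error
    return True


def sleep_seconds(remaining_retries: int, num_retries: int,
                  healthy_deployments: Optional[List] = None) -> float:
    if healthy_deployments:  # another deployment is available: retry immediately
        return 0.0
    return float(min(2 ** (num_retries - remaining_retries), 10))  # illustrative back-off


deployments = [{"model_info": {"id": "deployment-a"}}]
try:
    should_retry(RateLimitError("429"), healthy_deployments=deployments)
    print("retry after", sleep_seconds(2, 3, healthy_deployments=deployments), "s")
except Exception as err:
    print("not retrying:", err)
```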
@@ -8,8 +8,6 @@
 import dotenv, os, requests, random  # type: ignore
 from typing import Optional

-dotenv.load_dotenv()  # Loading env variables using dotenv
-
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

@@ -1,12 +1,11 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
 from pydantic import BaseModel, Extra, Field, root_validator
-import dotenv, os, requests, random  # type: ignore
+import os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

@@ -5,8 +5,6 @@ import dotenv, os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

@@ -4,8 +4,6 @@
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm import token_counter
 from litellm.caching import DualCache

@@ -5,8 +5,6 @@ import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
 from datetime import datetime
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback, asyncio, httpx
 import litellm
 from litellm import token_counter
@@ -312,7 +312,7 @@ async def test_langfuse_logging_metadata(langfuse_client):
         metadata["existing_trace_id"] = trace_id

     langfuse_client.flush()
-    await asyncio.sleep(2)
+    await asyncio.sleep(10)

     # Tests the metadata filtering and the override of the output to be the last generation
     for trace_id, generation_ids in trace_identifiers.items():
@@ -14,7 +14,6 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
 from litellm import Router, mock_completion
 from litellm.proxy.utils import ProxyLogging
 from litellm.proxy._types import UserAPIKeyAuth

@@ -22,11 +21,14 @@ from litellm.caching import DualCache


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_strict_input_filtering_01():
     """
     - have a response with a filtered input
     - call the pre call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -54,11 +56,14 @@ async def test_strict_input_filtering_01():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_strict_input_filtering_02():
     """
     - have a response with a filtered input
     - call the pre call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -81,11 +86,14 @@ async def test_strict_input_filtering_02():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_loose_input_filtering_01():
     """
     - have a response with a filtered input
     - call the pre call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -108,11 +116,14 @@ async def test_loose_input_filtering_01():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_loose_input_filtering_02():
     """
     - have a response with a filtered input
     - call the pre call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -135,11 +146,14 @@ async def test_loose_input_filtering_02():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_strict_output_filtering_01():
     """
     - have a response with a filtered output
     - call the post call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -172,11 +186,14 @@ async def test_strict_output_filtering_01():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_strict_output_filtering_02():
     """
     - have a response with a filtered output
     - call the post call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -204,11 +221,14 @@ async def test_strict_output_filtering_02():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_loose_output_filtering_01():
     """
     - have a response with a filtered output
     - call the post call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),

@@ -236,11 +256,14 @@ async def test_loose_output_filtering_01():


 @pytest.mark.asyncio
+@pytest.mark.skip(reason="beta feature - local testing is failing")
 async def test_loose_output_filtering_02():
     """
     - have a response with a filtered output
     - call the post call hook
     """
+    from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety
+
     azure_content_safety = _PROXY_AzureContentSafety(
         endpoint=os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT"),
         api_key=os.getenv("AZURE_CONTENT_SAFETY_API_KEY"),
@@ -11,7 +11,15 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion, acompletion, acreate, completion_cost, Timeout, ModelResponse
+from litellm import (
+    embedding,
+    completion,
+    acompletion,
+    acreate,
+    completion_cost,
+    Timeout,
+    ModelResponse,
+)
 from litellm import RateLimitError

 # litellm.num_retries = 3

@@ -20,6 +28,7 @@ litellm.success_callback = []
 user_message = "Write a short poem about the sky"
 messages = [{"content": user_message, "role": "user"}]

+
 @pytest.fixture(autouse=True)
 def reset_callbacks():
     print("\npytest fixture - resetting callbacks")

@@ -27,28 +36,29 @@ def reset_callbacks():
     litellm._async_success_callback = []
     litellm.failure_callback = []
     litellm.callbacks = []

+
 def test_completion_clarifai_claude_2_1():
     print("calling clarifai claude completion")
     import os

     clarifai_pat = os.environ["CLARIFAI_API_KEY"]

     try:
         response = completion(
             model="clarifai/anthropic.completion.claude-2_1",
             messages=messages,
             max_tokens=10,
             temperature=0.1,
         )
         print(response)

     except RateLimitError:
         pass

     except Exception as e:
         pytest.fail(f"Error occured: {e}")


 def test_completion_clarifai_mistral_large():
     try:

@@ -66,7 +76,8 @@ def test_completion_clarifai_mistral_large():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
 @pytest.mark.asyncio
 def test_async_completion_clarifai():
     import asyncio

@@ -88,6 +99,5 @@ def test_async_completion_clarifai():
         pass
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")

-
     asyncio.run(test_get_response())
@ -1305,7 +1305,7 @@ def test_hf_classifier_task():
|
||||||
|
|
||||||
########################### End of Hugging Face Tests ##############################################
|
########################### End of Hugging Face Tests ##############################################
|
||||||
# def test_completion_hf_api():
|
# def test_completion_hf_api():
|
||||||
# # failing on circle ci commenting out
|
# # failing on circle-ci commenting out
|
||||||
# try:
|
# try:
|
||||||
# user_message = "write some code to find the sum of two numbers"
|
# user_message = "write some code to find the sum of two numbers"
|
||||||
# messages = [{ "content": user_message,"role": "user"}]
|
# messages = [{ "content": user_message,"role": "user"}]
|
||||||
|
@ -2584,6 +2584,69 @@ def test_completion_chat_sagemaker_mistral():
|
||||||
# test_completion_chat_sagemaker_mistral()
|
# test_completion_chat_sagemaker_mistral()
|
||||||
|
|
||||||
|
|
||||||
|
def response_format_tests(response: litellm.ModelResponse):
|
||||||
|
assert isinstance(response.id, str)
|
||||||
|
assert response.id != ""
|
||||||
|
|
||||||
|
assert isinstance(response.object, str)
|
||||||
|
assert response.object != ""
|
||||||
|
|
||||||
|
assert isinstance(response.created, int)
|
||||||
|
|
||||||
|
assert isinstance(response.model, str)
|
||||||
|
assert response.model != ""
|
||||||
|
|
||||||
|
assert isinstance(response.choices, list)
|
||||||
|
assert len(response.choices) == 1
|
||||||
|
choice = response.choices[0]
|
||||||
|
assert isinstance(choice, litellm.Choices)
|
||||||
|
assert isinstance(choice.get("index"), int)
|
||||||
|
|
||||||
|
message = choice.get("message")
|
||||||
|
assert isinstance(message, litellm.Message)
|
||||||
|
assert isinstance(message.get("role"), str)
|
||||||
|
assert message.get("role") != ""
|
||||||
|
assert isinstance(message.get("content"), str)
|
||||||
|
assert message.get("content") != ""
|
||||||
|
|
||||||
|
assert choice.get("logprobs") is None
|
||||||
|
assert isinstance(choice.get("finish_reason"), str)
|
||||||
|
assert choice.get("finish_reason") != ""
|
||||||
|
|
||||||
|
assert isinstance(response.usage, litellm.Usage) # type: ignore
|
||||||
|
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
|
||||||
|
assert isinstance(response.usage.completion_tokens, int) # type: ignore
|
||||||
|
assert isinstance(response.usage.total_tokens, int) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_bedrock_command_r(sync_mode):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
if sync_mode:
|
||||||
|
response = completion(
|
||||||
|
model="bedrock/cohere.command-r-plus-v1:0",
|
||||||
|
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, litellm.ModelResponse)
|
||||||
|
|
||||||
|
response_format_tests(response=response)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="bedrock/cohere.command-r-plus-v1:0",
|
||||||
|
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, litellm.ModelResponse)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
response_format_tests(response=response)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
|
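The new `response_format_tests` helper above pins down the OpenAI-compatible shape of `litellm.ModelResponse`. For orientation, a minimal sketch using only the accessors it asserts (model name and prompt taken from the test; AWS credentials for Bedrock are assumed to be configured):

```python
# Minimal sketch - relies only on fields that response_format_tests asserts.
import litellm

response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
)

print(response.id, response.model, response.created)   # request metadata
print(response.choices[0].message.content)             # assistant reply
print(response.choices[0].finish_reason)                # e.g. "stop"
print(response.usage.prompt_tokens, response.usage.completion_tokens)
```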
||||||
def test_completion_bedrock_titan_null_response():
|
def test_completion_bedrock_titan_null_response():
|
||||||
try:
|
try:
|
||||||
response = completion(
|
response = completion(
|
||||||
|
@@ -3236,6 +3299,7 @@ def test_completion_watsonx():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_completion_stream_watsonx():
|
def test_completion_stream_watsonx():
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
model_name = "watsonx/ibm/granite-13b-chat-v2"
|
model_name = "watsonx/ibm/granite-13b-chat-v2"
|
||||||
|
@@ -3245,7 +3309,7 @@ def test_completion_stream_watsonx():
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stop=["stop"],
|
stop=["stop"],
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
@@ -3318,6 +3382,7 @@ async def test_acompletion_watsonx():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_acompletion_stream_watsonx():
|
async def test_acompletion_stream_watsonx():
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
@@ -3329,7 +3394,7 @@ async def test_acompletion_stream_watsonx():
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=80,
|
max_tokens=80,
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
|
|
|
@@ -83,7 +83,6 @@ def test_async_fallbacks(caplog):
|
||||||
# - error request, falling back notice, success notice
|
# - error request, falling back notice, success notice
|
||||||
expected_logs = [
|
expected_logs = [
|
||||||
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
|
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
|
||||||
"litellm.acompletion(model=None)\x1b[31m Exception No deployments available for selected model, passed model=gpt-3.5-turbo\x1b[0m",
|
|
||||||
"Falling back to model_group = azure/gpt-3.5-turbo",
|
"Falling back to model_group = azure/gpt-3.5-turbo",
|
||||||
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
|
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
|
||||||
]
|
]
|
||||||
|
|
|
@@ -269,7 +269,7 @@ def test_sync_fallbacks_embeddings():
|
||||||
response = router.embedding(**kwargs)
|
response = router.embedding(**kwargs)
|
||||||
print(f"customHandler.previous_models: {customHandler.previous_models}")
|
print(f"customHandler.previous_models: {customHandler.previous_models}")
|
||||||
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
|
time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
|
||||||
assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
|
assert customHandler.previous_models == 1  # 1 init call with a bad key
|
||||||
router.reset()
|
router.reset()
|
||||||
except litellm.Timeout as e:
|
except litellm.Timeout as e:
|
||||||
pass
|
pass
|
||||||
|
@@ -323,7 +323,7 @@ async def test_async_fallbacks_embeddings():
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
0.05
|
0.05
|
||||||
) # allow a delay as success_callbacks are on a separate thread
|
) # allow a delay as success_callbacks are on a separate thread
|
||||||
assert customHandler.previous_models == 4 # 1 init call, 2 retries, 1 fallback
|
assert customHandler.previous_models == 1 # 1 init call with a bad key
|
||||||
router.reset()
|
router.reset()
|
||||||
except litellm.Timeout as e:
|
except litellm.Timeout as e:
|
||||||
pass
|
pass
|
||||||
|
|
|
@@ -12,6 +12,7 @@ sys.path.insert(
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
import openai, httpx
|
||||||
|
|
||||||
|
|
||||||
class MyCustomHandler(CustomLogger):
|
class MyCustomHandler(CustomLogger):
|
||||||
|
@@ -191,8 +192,8 @@ async def test_dynamic_router_retry_policy(model_group):
|
||||||
from litellm.router import RetryPolicy
|
from litellm.router import RetryPolicy
|
||||||
|
|
||||||
model_group_retry_policy = {
|
model_group_retry_policy = {
|
||||||
"gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=0),
|
"gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=2),
|
||||||
"bad-model": RetryPolicy(AuthenticationErrorRetries=4),
|
"bad-model": RetryPolicy(AuthenticationErrorRetries=0),
|
||||||
}
|
}
|
||||||
|
|
||||||
router = litellm.Router(
|
router = litellm.Router(
|
||||||
|
@@ -205,6 +206,33 @@ async def test_dynamic_router_retry_policy(model_group):
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
"api_base": os.getenv("AZURE_API_BASE"),
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
},
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "model-0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "model-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "model-2",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model_name": "bad-model", # openai model name
|
"model_name": "bad-model", # openai model name
|
||||||
|
@@ -240,6 +268,264 @@ async def test_dynamic_router_retry_policy(model_group):
|
||||||
print("customHandler.previous_models: ", customHandler.previous_models)
|
print("customHandler.previous_models: ", customHandler.previous_models)
|
||||||
|
|
||||||
if model_group == "bad-model":
|
if model_group == "bad-model":
|
||||||
assert customHandler.previous_models == 4
|
|
||||||
elif model_group == "gpt-3.5-turbo":
|
|
||||||
assert customHandler.previous_models == 0
|
assert customHandler.previous_models == 0
|
||||||
|
elif model_group == "gpt-3.5-turbo":
|
||||||
|
assert customHandler.previous_models == 2
|
||||||
|
|
||||||
|
|
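The updated assertions above follow from the per-model-group `RetryPolicy` configured earlier in this test. A short sketch of wiring such a policy into a router follows; passing the dict via a `model_group_retry_policy` argument is assumed here, and the Azure deployment values are the placeholders used throughout these tests.

```python
# Sketch: per-model-group retry counts, mirroring the policy used in the test above.
import os
import litellm
from litellm.router import RetryPolicy

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
        },
    ],
    # retry content-policy violations twice; fail fast on bad credentials
    model_group_retry_policy={
        "gpt-3.5-turbo": RetryPolicy(ContentPolicyViolationErrorRetries=2),
        "bad-model": RetryPolicy(AuthenticationErrorRetries=0),
    },
)
```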
||||||
|
"""
|
||||||
|
Unit Tests for Router Retry Logic
|
||||||
|
|
||||||
|
Test 1. Retry Rate Limit Errors when there are other healthy deployments
|
||||||
|
|
||||||
|
Test 2. Do not retry rate limit errors when - there are no fallbacks and no healthy deployments
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
rate_limit_error = openai.RateLimitError(
|
||||||
|
message="Rate limit exceeded",
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=429,
|
||||||
|
request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
|
||||||
|
),
|
||||||
|
body={
|
||||||
|
"error": {
|
||||||
|
"type": "rate_limit_exceeded",
|
||||||
|
"param": None,
|
||||||
|
"code": "rate_limit_exceeded",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_rate_limit_error_with_healthy_deployments():
|
||||||
|
"""
|
||||||
|
Test 1. It SHOULD retry when there is a rate limit error and len(healthy_deployments) > 0
|
||||||
|
"""
|
||||||
|
healthy_deployments = [
|
||||||
|
"deployment1",
|
||||||
|
"deployment2",
|
||||||
|
] # multiple healthy deployments mocked up
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
try:
|
||||||
|
response = router.should_retry_this_error(
|
||||||
|
error=rate_limit_error, healthy_deployments=healthy_deployments
|
||||||
|
)
|
||||||
|
print("response from should_retry_this_error: ", response)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"Should not have raised an error, since there are healthy deployments. Raised: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_do_not_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployments():
|
||||||
|
"""
|
||||||
|
Test 2. It SHOULD NOT Retry, when healthy_deployments is [] and fallbacks is None
|
||||||
|
"""
|
||||||
|
healthy_deployments = []
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
try:
|
||||||
|
response = router.should_retry_this_error(
|
||||||
|
error=rate_limit_error, healthy_deployments=healthy_deployments
|
||||||
|
)
|
||||||
|
assert response != True, "Should have raised RateLimitError"
|
||||||
|
except openai.RateLimitError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_context_window_exceeded_error():
|
||||||
|
"""
|
||||||
|
Retry Context Window Exceeded Error, when context_window_fallbacks is not None
|
||||||
|
"""
|
||||||
|
context_window_error = litellm.ContextWindowExceededError(
|
||||||
|
message="Context window exceeded",
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
|
||||||
|
),
|
||||||
|
llm_provider="azure",
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
context_window_fallbacks = [{"gpt-3.5-turbo": ["azure/chatgpt-v-2"]}]
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = router.should_retry_this_error(
|
||||||
|
error=context_window_error,
|
||||||
|
healthy_deployments=None,
|
||||||
|
context_window_fallbacks=context_window_fallbacks,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response == True
|
||||||
|
), "Should not have raised exception since we have context window fallbacks"
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_context_window_exceeded_error_no_retry():
|
||||||
|
"""
|
||||||
|
Do not Retry Context Window Exceeded Error, when context_window_fallbacks is None
|
||||||
|
"""
|
||||||
|
context_window_error = litellm.ContextWindowExceededError(
|
||||||
|
message="Context window exceeded",
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
|
||||||
|
),
|
||||||
|
llm_provider="azure",
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
)
|
||||||
|
context_window_fallbacks = None
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = router.should_retry_this_error(
|
||||||
|
error=context_window_error,
|
||||||
|
healthy_deployments=None,
|
||||||
|
context_window_fallbacks=context_window_fallbacks,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response != True
|
||||||
|
), "Should have raised exception since we do not have context window fallbacks"
|
||||||
|
except litellm.ContextWindowExceededError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
## Unit test time to back off for router retries
|
||||||
|
|
||||||
|
"""
|
||||||
|
1. Timeout is 0.0 when RateLimit Error and healthy deployments are > 0
|
||||||
|
2. Timeout is 0.0 when RateLimit Error and fallbacks are > 0
|
||||||
|
3. Timeout is > 0.0 when RateLimit Error and healthy deployments == 0 and fallbacks == None
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_for_rate_limit_error_with_healthy_deployments():
|
||||||
|
"""
|
||||||
|
Test 1. Timeout is 0.0 when RateLimit Error and healthy deployments are > 0
|
||||||
|
"""
|
||||||
|
healthy_deployments = [
|
||||||
|
"deployment1",
|
||||||
|
"deployment2",
|
||||||
|
] # multiple healthy deployments mocked up
|
||||||
|
fallbacks = None
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
_timeout = router._time_to_sleep_before_retry(
|
||||||
|
e=rate_limit_error,
|
||||||
|
remaining_retries=4,
|
||||||
|
num_retries=4,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"timeout=",
|
||||||
|
_timeout,
|
||||||
|
"error is rate_limit_error and there are healthy deployments=",
|
||||||
|
healthy_deployments,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _timeout == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
|
||||||
|
"""
|
||||||
|
Test 2. Timeout is > 0.0 when RateLimit Error and healthy deployments == 0
|
||||||
|
"""
|
||||||
|
healthy_deployments = []
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
_timeout = router._time_to_sleep_before_retry(
|
||||||
|
e=rate_limit_error,
|
||||||
|
remaining_retries=4,
|
||||||
|
num_retries=4,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"timeout=",
|
||||||
|
_timeout,
|
||||||
|
"error is rate_limit_error and there are no healthy deployments",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _timeout > 0.0
|
||||||
|
|
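The two timing tests above pin down behaviour rather than a formula: no sleep when other healthy deployments can take the request, a growing back-off when none can. A rough illustration of that rule (not litellm's implementation; the exponential base is an assumption):

```python
# Illustration only - mirrors what the tests assert, not litellm's internals.
from typing import List, Optional


def sleep_before_retry(
    healthy_deployments: Optional[List[str]],
    remaining_retries: int,
    num_retries: int,
    base: float = 1.0,
) -> float:
    if healthy_deployments:
        return 0.0  # another deployment can absorb the request immediately
    attempt = num_retries - remaining_retries
    return base * (2**attempt)  # back off when every deployment is rate limited


assert sleep_before_retry(["deployment1", "deployment2"], 4, 4) == 0.0
assert sleep_before_retry([], 4, 4) > 0.0
```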
|
@@ -132,12 +132,15 @@ def test_post_call_rule_streaming():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_post_call_processing_error_async_response():
|
@pytest.mark.asyncio
|
||||||
response = asyncio.run(
|
async def test_post_call_processing_error_async_response():
|
||||||
acompletion(
|
try:
|
||||||
|
response = await acompletion(
|
||||||
model="command-nightly", # Just used as an example
|
model="command-nightly", # Just used as an example
|
||||||
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||||
api_base="https://openai-proxy.berriai.repl.co", # Just used as an example
|
api_base="https://openai-proxy.berriai.repl.co", # Just used as an example
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider="openai",
|
||||||
)
|
)
|
||||||
)
|
pytest.fail("This call should have failed")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
|
@@ -983,6 +983,64 @@ def test_vertex_ai_stream():
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bedrock_cohere_command_r_streaming(sync_mode):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
if sync_mode:
|
||||||
|
final_chunk: Optional[litellm.ModelResponse] = None
|
||||||
|
response: litellm.CustomStreamWrapper = completion( # type: ignore
|
||||||
|
model="bedrock/cohere.command-r-plus-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10, # type: ignore
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
complete_response = ""
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
has_finish_reason = False
|
||||||
|
for idx, chunk in enumerate(response):
|
||||||
|
final_chunk = chunk
|
||||||
|
chunk, finished = streaming_format_tests(idx, chunk)
|
||||||
|
if finished:
|
||||||
|
has_finish_reason = True
|
||||||
|
break
|
||||||
|
complete_response += chunk
|
||||||
|
if has_finish_reason == False:
|
||||||
|
raise Exception("finish reason not set")
|
||||||
|
if complete_response.strip() == "":
|
||||||
|
raise Exception("Empty response received")
|
||||||
|
else:
|
||||||
|
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
|
||||||
|
model="bedrock/cohere.command-r-plus-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=100, # type: ignore
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
complete_response = ""
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
has_finish_reason = False
|
||||||
|
idx = 0
|
||||||
|
final_chunk: Optional[litellm.ModelResponse] = None
|
||||||
|
async for chunk in response:
|
||||||
|
final_chunk = chunk
|
||||||
|
chunk, finished = streaming_format_tests(idx, chunk)
|
||||||
|
if finished:
|
||||||
|
has_finish_reason = True
|
||||||
|
break
|
||||||
|
complete_response += chunk
|
||||||
|
idx += 1
|
||||||
|
if has_finish_reason == False:
|
||||||
|
raise Exception("finish reason not set")
|
||||||
|
if complete_response.strip() == "":
|
||||||
|
raise Exception("Empty response received")
|
||||||
|
print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}")
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_bedrock_claude_3_streaming():
|
def test_bedrock_claude_3_streaming():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
63 litellm/types/llms/bedrock.py Normal file
|
@@ -0,0 +1,63 @@
|
||||||
|
from typing import TypedDict, Any, Union, Optional
|
||||||
|
import json
|
||||||
|
from typing_extensions import (
|
||||||
|
Self,
|
||||||
|
Protocol,
|
||||||
|
TypeGuard,
|
||||||
|
override,
|
||||||
|
get_origin,
|
||||||
|
runtime_checkable,
|
||||||
|
Required,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericStreamingChunk(TypedDict):
|
||||||
|
text: Required[str]
|
||||||
|
is_finished: Required[bool]
|
||||||
|
finish_reason: Required[str]
|
||||||
|
|
||||||
|
|
||||||
|
class Document(TypedDict):
|
||||||
|
title: str
|
||||||
|
snippet: str
|
||||||
|
|
||||||
|
|
||||||
|
class ServerSentEvent:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
event: Optional[str] = None,
|
||||||
|
data: Optional[str] = None,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
retry: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
if data is None:
|
||||||
|
data = ""
|
||||||
|
|
||||||
|
self._id = id
|
||||||
|
self._data = data
|
||||||
|
self._event = event or None
|
||||||
|
self._retry = retry
|
||||||
|
|
||||||
|
@property
|
||||||
|
def event(self) -> Optional[str]:
|
||||||
|
return self._event
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> Optional[str]:
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retry(self) -> Optional[int]:
|
||||||
|
return self._retry
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> str:
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
def json(self) -> Any:
|
||||||
|
return json.loads(self.data)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
|
|
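A small sketch of how the new Bedrock streaming types above could be exercised; the SSE payload is a made-up example:

```python
# Sketch using the types defined above; the data payload is illustrative.
from litellm.types.llms.bedrock import GenericStreamingChunk, ServerSentEvent

chunk: GenericStreamingChunk = {
    "text": "Hello",
    "is_finished": False,
    "finish_reason": "",
}

event = ServerSentEvent(
    event="message",
    data='{"text": " world", "is_finished": true, "finish_reason": "COMPLETE"}',
)
parsed = event.json()  # json.loads over the data payload
print(chunk["text"] + parsed["text"], parsed["finish_reason"])
```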
@@ -132,7 +132,6 @@ MAX_THREADS = 100
|
||||||
|
|
||||||
# Create a ThreadPoolExecutor
|
# Create a ThreadPoolExecutor
|
||||||
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
sentry_sdk_instance = None
|
sentry_sdk_instance = None
|
||||||
capture_exception = None
|
capture_exception = None
|
||||||
add_breadcrumb = None
|
add_breadcrumb = None
|
||||||
|
@@ -8217,10 +8216,7 @@ def exception_type(
|
||||||
+ "Exception"
|
+ "Exception"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if "This model's maximum context length is" in error_str:
|
||||||
"This model's maximum context length is" in error_str
|
|
||||||
or "Request too large" in error_str
|
|
||||||
):
|
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise ContextWindowExceededError(
|
raise ContextWindowExceededError(
|
||||||
message=f"{exception_provider} - {message} {extra_information}",
|
message=f"{exception_provider} - {message} {extra_information}",
|
||||||
|
@@ -8261,6 +8257,13 @@ def exception_type(
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
)
|
)
|
||||||
|
elif "Request too large" in error_str:
|
||||||
|
raise RateLimitError(
|
||||||
|
message=f"{exception_provider} - {message} {extra_information}",
|
||||||
|
model=model,
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=original_exception.response,
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
|
||||||
in error_str
|
in error_str
|
||||||
|
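With the change above, OpenAI's "Request too large" responses surface as a retryable `RateLimitError` instead of a `ContextWindowExceededError`, so callers can back off on the former and shorten the prompt on the latter. A rough sketch of caller-side handling (model and prompt are placeholders):

```python
# Sketch: the two exception classes now call for different remedies.
import litellm

try:
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
except litellm.ContextWindowExceededError:
    # prompt exceeds the model's context length - trim it or switch models
    raise
except litellm.RateLimitError:
    # now also covers "Request too large" - safe to back off and retry
    pass
```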
@@ -10467,6 +10470,12 @@ class CustomStreamWrapper:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def handle_bedrock_stream(self, chunk):
|
def handle_bedrock_stream(self, chunk):
|
||||||
|
if "cohere" in self.model:
|
||||||
|
return {
|
||||||
|
"text": chunk["text"],
|
||||||
|
"is_finished": chunk["is_finished"],
|
||||||
|
"finish_reason": chunk["finish_reason"],
|
||||||
|
}
|
||||||
if hasattr(chunk, "get"):
|
if hasattr(chunk, "get"):
|
||||||
chunk = chunk.get("chunk")
|
chunk = chunk.get("chunk")
|
||||||
chunk_data = json.loads(chunk.get("bytes").decode())
|
chunk_data = json.loads(chunk.get("bytes").decode())
|
||||||
|
@@ -11315,6 +11324,7 @@ class CustomStreamWrapper:
|
||||||
or self.custom_llm_provider == "gemini"
|
or self.custom_llm_provider == "gemini"
|
||||||
or self.custom_llm_provider == "cached_response"
|
or self.custom_llm_provider == "cached_response"
|
||||||
or self.custom_llm_provider == "predibase"
|
or self.custom_llm_provider == "predibase"
|
||||||
|
or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
|
||||||
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||||
):
|
):
|
||||||
async for chunk in self.completion_stream:
|
async for chunk in self.completion_stream:
|
||||||
|
|
|
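The Cohere branch added to `handle_bedrock_stream`, together with the provider check above it, routes Bedrock Command R chunks through the generic streaming path, so callers consume them like any other OpenAI-style stream. A minimal sketch (AWS credentials assumed to be configured):

```python
# Sketch: streaming a Bedrock Cohere Command R model through litellm.
import litellm

response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    max_tokens=10,
    stream=True,
)

for chunk in response:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")
```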
@@ -2644,6 +2644,24 @@
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
"cohere.command-r-plus-v1:0": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000030,
|
||||||
|
"output_cost_per_token": 0.000015,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"cohere.command-r-v1:0": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000005,
|
||||||
|
"output_cost_per_token": 0.0000015,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
"cohere.embed-english-v3": {
|
"cohere.embed-english-v3": {
|
||||||
"max_tokens": 512,
|
"max_tokens": 512,
|
||||||
"max_input_tokens": 512,
|
"max_input_tokens": 512,
|
||||||
|
|
|
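For orientation, the per-token prices added above translate into request cost as `prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token`; a worked example with made-up token counts:

```python
# cohere.command-r-plus-v1:0 prices from the entry above; token counts are made up.
input_cost_per_token = 0.0000030
output_cost_per_token = 0.000015

prompt_tokens, completion_tokens = 1_000, 500

cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${cost:.4f}")  # $0.0105
```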
@@ -1,9 +1,10 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-3.5-turbo
|
- model_name: gpt-3.5-turbo
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/gpt-35-turbo
|
model: gpt-3.5-turbo
|
||||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
region_name: "eu"
|
||||||
api_key: os.environ/AZURE_EUROPE_API_KEY
|
model_info:
|
||||||
|
id: "1"
|
||||||
- model_name: gpt-3.5-turbo
|
- model_name: gpt-3.5-turbo
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/chatgpt-v-2
|
model: azure/chatgpt-v-2
|
||||||
|
@@ -83,6 +84,7 @@ model_list:
|
||||||
model: text-completion-openai/gpt-3.5-turbo-instruct
|
model: text-completion-openai/gpt-3.5-turbo-instruct
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
drop_params: True
|
drop_params: True
|
||||||
|
enable_preview_features: True
|
||||||
# max_budget: 100
|
# max_budget: 100
|
||||||
# budget_duration: 30d
|
# budget_duration: 30d
|
||||||
num_retries: 5
|
num_retries: 5
|
||||||
|
|
|
@@ -167,7 +167,4 @@ async def test_end_user_specific_region():
|
||||||
user=end_user_obj["user_id"],
|
user=end_user_obj["user_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert result.headers.get("x-litellm-model-id") == "1"
|
||||||
result.headers.get("x-litellm-model-api-base")
|
|
||||||
== "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
|
||||||
)
|
|
||||||
|
|
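The updated test above reads the serving deployment's id back from the `x-litellm-model-id` response header, which comes from the `model_info.id: "1"` set on the EU deployment in the proxy config. A rough sketch of checking that header against a running proxy (URL, key, and end-user id are placeholders):

```python
# Sketch: inspect which deployment served the request via response headers.
import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "user": "my-eu-end-user",  # end user mapped to the eu region
    },
)
print(resp.headers.get("x-litellm-model-id"))  # "1" -> the eu deployment served it
```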