Merge branch 'main' into patch-1

Commit 03246af094 by Krish Dholakia, 2024-03-05 07:49:48 -08:00 (committed by GitHub)
52 changed files with 1962 additions and 446 deletions

View file

@ -1,9 +1,12 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Anthropic
LiteLLM supports
- `claude-3` (`claude-3-opus-20240229`, `claude-3-sonnet-20240229`)
- `claude-2`
- `claude-2.1`
- `claude-instant-1`
- `claude-instant-1.2`
## API Keys
@ -24,11 +27,133 @@ from litellm import completion
os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
messages = [{"role": "user", "content": "Hey! how's it going?"}]
response = completion(model="claude-instant-1", messages=messages)
response = completion(model="claude-3-opus-20240229", messages=messages)
print(response)
```
## Usage - "Assistant Pre-fill"
## Usage - Streaming
Just set `stream=True` when calling completion.
```python
import os
from litellm import completion
# set env
os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
messages = [{"role": "user", "content": "Hey! how's it going?"}]
response = completion(model="claude-3-opus-20240229", messages=messages, stream=True)
for chunk in response:
print(chunk["choices"][0]["delta"]["content"]) # same as openai format
```
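If you want the assembled reply rather than printing each delta, you can accumulate the streamed chunks. This is a minimal variation on the loop above (not taken from the docs); the `or ""` guard covers chunks whose delta carries no content.
```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
messages = [{"role": "user", "content": "Hey! how's it going?"}]
response = completion(model="claude-3-opus-20240229", messages=messages, stream=True)

full_reply = ""
for chunk in response:
    # same OpenAI-style access as above; guard against a delta with no content
    full_reply += chunk["choices"][0]["delta"]["content"] or ""
print(full_reply)
```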
## OpenAI Proxy Usage
Here's how to call Anthropic with the LiteLLM Proxy Server
### 1. Save key in your environment
```bash
export ANTHROPIC_API_KEY="your-api-key"
```
### 2. Start the proxy
```bash
$ litellm --model claude-3-opus-20240229
# Server running on http://0.0.0.0:8000
```
### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:8000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:8000", # set openai_api_base to the LiteLLM Proxy
model = "gpt-3.5-turbo",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
## Supported Models
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------------|--------------------------------------|
| claude-3-opus | `completion('claude-3-opus-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-3-sonnet | `completion('claude-3-sonnet-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
## Advanced
### Usage - "Assistant Pre-fill"
You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array.
@ -50,7 +175,7 @@ response = completion(model="claude-2.1", messages=messages)
print(response)
```
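The diff only shows the tail of this example. A complete sketch, consistent with the snippet above and the example prompt below, looks like this (illustrative, not copied verbatim from the docs):
```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "your-api-key"

messages = [
    {
        "role": "user",
        "content": "How do you say 'Hello' in German? Return your answer as a JSON object, like this:\n\n{\"Hello\": \"Hallo\"}",
    },
    {"role": "assistant", "content": "{"},  # pre-filled start of Claude's answer
]
response = completion(model="claude-2.1", messages=messages)
print(response)
```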
#### Example prompt sent to Claude
```
@ -61,7 +186,7 @@ Human: How do you say 'Hello' in German? Return your answer as a JSON object, li
Assistant: {
```
## Usage - "System" messages
### Usage - "System" messages
If you're using Anthropic's Claude 2.1 with Bedrock, `system` role messages are properly formatted for you.
```python
@ -78,7 +203,7 @@ messages = [
response = completion(model="claude-2.1", messages=messages)
```
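Again the diff truncates the example; a complete sketch consistent with the example prompt below (illustrative):
```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "your-api-key"

messages = [
    {"role": "system", "content": "You are a snarky assistant."},
    {"role": "user", "content": "How do I boil water?"},
]
response = completion(model="claude-2.1", messages=messages)
print(response)
```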
#### Example prompt sent to Claude
```
You are a snarky assistant.
@ -88,28 +213,3 @@ Human: How do I boil water?
Assistant:
```
## Streaming
Just set `stream=True` when calling completion.
```python
import os
from litellm import completion
# set env
os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
messages = [{"role": "user", "content": "Hey! how's it going?"}]
response = completion(model="claude-instant-1", messages=messages, stream=True)
for chunk in response:
print(chunk["choices"][0]["delta"]["content"]) # same as openai format
```
### Model Details
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------------|--------------------------------------|
| claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |

View file

@ -286,18 +286,20 @@ response = litellm.embedding(
## Supported AWS Bedrock Models
Here's an example of using a bedrock model with LiteLLM
| Model Name | Command | Required OS Variables |
|----------------------------|------------------------------------------------------------------|--------------------------------------|
| Anthropic Claude-V2.1 | `completion(model='bedrock/anthropic.claude-v2:1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Amazon Titan Lite | `completion(model='bedrock/amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Amazon Titan Express | `completion(model='bedrock/amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Cohere Command | `completion(model='bedrock/cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 13b | `completion(model='bedrock/meta.llama2-13b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Meta Llama 2 Chat 70b | `completion(model='bedrock/meta.llama2-70b-chat-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mistral 7B Instruct | `completion(model='bedrock/mistral.mistral-7b-instruct-v0:2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Mixtral 8x7B Instruct | `completion(model='bedrock/mistral.mixtral-8x7b-instruct-v0:1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
## Bedrock Embedding

View file

@ -202,7 +202,7 @@ print(response)
</Tabs>
## Save Model-specific params (API Base, Keys, Temperature, Max Tokens, Organization, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.
[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
@ -244,6 +244,45 @@ $ litellm --config /path/to/config.yaml
```
## Load Balancing
Use this to call multiple instances of the same model and configure things like [routing strategy](../routing.md#advanced).
```yaml
router_settings:
routing_strategy: "latency-based-routing" # routes to the fastest deployment in the group
model_list:
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8001
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8002
- model_name: zephyr-beta
litellm_params:
model: huggingface/HuggingFaceH4/zephyr-7b-beta
api_base: http://0.0.0.0:8003
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo
api_key: <my-openai-key>
- model_name: gpt-3.5-turbo-16k
litellm_params:
model: gpt-3.5-turbo-16k
api_key: <my-openai-key>
litellm_settings:
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
allowed_fails: 3 # cooldown model if it fails > 3 calls in a minute.
```
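Once the proxy is started with this config, clients address the `model_name` group and the router picks a deployment per request. A minimal sketch, assuming the proxy is running on port 8000 as in the earlier examples:
```python
import openai

client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:8000")

# "zephyr-beta" is the model_name group defined above; the proxy's router picks
# one of the three deployments per request according to routing_strategy.
response = client.chat.completions.create(
    model="zephyr-beta",
    messages=[{"role": "user", "content": "this is a load balancing test"}],
)
print(response)
```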
## Set Azure `base_model` for cost tracking
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
@ -512,30 +551,6 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
```
## Router Settings
Use this to configure things like routing strategy.
```yaml
router_settings:
routing_strategy: "least-busy"
model_list: # will route requests to the least busy ollama model
- model_name: ollama-models
litellm_params:
model: "ollama/mistral"
api_base: "http://127.0.0.1:8001"
- model_name: ollama-models
litellm_params:
model: "ollama/codellama"
api_base: "http://127.0.0.1:8002"
- model_name: ollama-models
litellm_params:
model: "ollama/llama2"
api_base: "http://127.0.0.1:8003"
```
## Configure DB Pool Limits + Connection Timeouts
```yaml

View file

@ -570,9 +570,11 @@ from .utils import (
_calculate_retry_after,
_should_retry,
get_secret,
get_mapped_model_params,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig
from .llms.ai21 import AI21Config
@ -591,9 +593,11 @@ from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,
AmazonAnthropicConfig,
AmazonAnthropicClaude3Config,
AmazonCohereConfig,
AmazonLlamaConfig,
AmazonStabilityConfig,
AmazonMistralConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError

View file

@ -77,9 +77,9 @@ class AlephAlphaConfig:
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
"""
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
maximum_tokens: Optional[int] = (
litellm.max_tokens
) # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None
@ -285,7 +285,10 @@ def completion(
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
encoding.encode(
model_response["choices"][0]["message"]["content"],
disallowed_special=(),
)
)
model_response["created"] = int(time.time())

View file

@ -2,11 +2,17 @@ import os, types
import json
from enum import Enum
import requests
import time
import time, uuid
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, map_finish_reason
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx
@ -20,7 +26,7 @@ class AnthropicError(Exception):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
method="POST", url="https://api.anthropic.com/v1/messages"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
@ -35,23 +41,23 @@ class AnthropicConfig:
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[
int
] = litellm.max_tokens # anthropic requires a default
max_tokens: Optional[int] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
system: Optional[str] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
max_tokens: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
system: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -110,6 +116,7 @@ def completion(
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -120,7 +127,17 @@ def completion(
messages=messages,
)
else:
prompt = prompt_factory(
# Separate system prompt from rest of message
system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
if message["role"] == "system":
optional_params["system"] = message["content"]
system_prompt_idx = idx
break
if system_prompt_idx is not None:
messages.pop(system_prompt_idx)
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
@ -132,15 +149,26 @@ def completion(
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
data = {
"model": model,
"prompt": prompt,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
@ -173,7 +201,7 @@ def completion(
## LOGGING
logging_obj.post_call(
input=prompt,
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
@ -191,20 +219,45 @@ def completion(
message=str(completion_response["error"]),
status_code=response.status_code,
)
elif len(completion_response["content"]) == 0:
raise AnthropicError(
message="No content in response",
status_code=response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
model_response.choices[0].finish_reason = completion_response["stop_reason"]
text_content = completion_response["content"][0].get("text", None)
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and "invoke" in text_content:
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt,disallowed_special=())
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""),disallowed_special=())
) ##[TODO] use the anthropic tokenizer here
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time())
model_response["model"] = model

View file

@ -0,0 +1,222 @@
import os, types
import json
from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AnthropicError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.anthropic.com/v1/complete"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AnthropicTextConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/complete_post
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int] = (
litellm.max_tokens
) # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
top_k: Optional[int] = None
metadata: Optional[dict] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = 256, # anthropic requires a default
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
top_k: Optional[int] = None,
metadata: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
# makes headers for API call
def validate_environment(api_key, user_headers):
if api_key is None:
raise ValueError(
"Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
)
headers = {
"accept": "application/json",
"anthropic-version": "2023-06-01",
"content-type": "application/json",
"x-api-key": api_key,
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## Load Config
config = litellm.AnthropicTextConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if len(completion_response["completion"]) > 0:
model_response["choices"][0]["message"]["content"] = (
completion_response["completion"]
)
model_response.choices[0].finish_reason = completion_response["stop_reason"]
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
) ##[TODO] use the anthropic tokenizer here
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
) ##[TODO] use the anthropic tokenizer here
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -1,11 +1,17 @@
import json, copy, types
import os
from enum import Enum
import time
import time, uuid
from typing import Callable, Optional, Any, Union, List
import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
from .prompt_templates.factory import prompt_factory, custom_prompt
from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx
@ -70,6 +76,59 @@ class AmazonTitanConfig:
}
class AmazonAnthropicClaude3Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` (integer) max tokens,
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
"""
max_tokens: Optional[int] = litellm.max_tokens
anthropic_version: Optional[str] = "bedrock-2023-05-31"
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return ["max_tokens", "tools", "tool_choice", "stream"]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
return optional_params
class AmazonAnthropicConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
@ -123,6 +182,25 @@ class AmazonAnthropicConfig:
and v is not None
}
def get_supported_openai_params(
self,
):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens_to_sample"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
return optional_params
class AmazonCohereConfig:
"""
@ -282,6 +360,56 @@ class AmazonLlamaConfig:
}
class AmazonMistralConfig:
"""
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
Supported Params for the Amazon / Mistral models:
- `max_tokens` (integer) max tokens,
- `temperature` (float) temperature for model,
- `top_p` (float) top p for model
- `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
- `top_k` (float) top k for model
"""
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[float] = None
stop: Optional[list[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[int] = None,
top_k: Optional[float] = None,
stop: Optional[list[str]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
@ -492,6 +620,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
for message in messages:
@ -568,14 +700,47 @@ def completion(
inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", False)
if provider == "anthropic":
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
if model.startswith("anthropic.claude-3"):
# Separate system prompt from rest of message
system_prompt_idx: Optional[int] = None
for idx, message in enumerate(messages):
if message["role"] == "system":
inference_params["system"] = message["content"]
system_prompt_idx = idx
break
if system_prompt_idx is not None:
messages.pop(system_prompt_idx)
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## Handle Tool Calling
if "tools" in inference_params:
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
inference_params["system"] = (
inference_params.get("system", "\n")
+ tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
inference_params.pop("tools")
data = json.dumps({"messages": messages, **inference_params})
else:
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
@ -595,9 +760,9 @@ def completion(
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if optional_params.get("stream", False) == True:
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "meta":
## LOAD CONFIG
@ -623,7 +788,16 @@ def completion(
"textGenerationConfig": inference_params,
}
)
elif provider == "mistral":
## LOAD CONFIG
config = litellm.AmazonMistralConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
else:
data = json.dumps({})
@ -723,12 +897,49 @@ def completion(
if provider == "ai21":
outputText = response_body.get("completions")[0].get("data").get("text")
elif provider == "anthropic":
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
if model.startswith("anthropic.claude-3"):
outputText = response_body.get("content")[0].get("text", None)
if "<invoke>" in outputText: # OUTPUT PARSE FUNCTION CALL
function_name = extract_between_tags("tool_name", outputText)[0]
function_arguments_str = extract_between_tags("invoke", outputText)[
0
].strip()
function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>"
)
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response["finish_reason"] = response_body["stop_reason"]
_usage = litellm.Usage(
prompt_tokens=response_body["usage"]["input_tokens"],
completion_tokens=response_body["usage"]["output_tokens"],
total_tokens=response_body["usage"]["input_tokens"]
+ response_body["usage"]["output_tokens"],
)
model_response.usage = _usage
else:
outputText = response_body["completion"]
model_response["finish_reason"] = response_body["stop_reason"]
elif provider == "cohere":
outputText = response_body["generations"][0]["text"]
elif provider == "meta":
outputText = response_body["generation"]
elif provider == "mistral":
outputText = response_body["outputs"][0]["text"]
model_response["finish_reason"] = response_body["outputs"][0]["stop_reason"]
else: # amazon titan
outputText = response_body.get("results")[0].get("outputText")
@ -740,8 +951,19 @@ def completion(
)
else:
try:
if len(outputText) > 0:
if (
len(outputText) > 0
and hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None)
is None
):
model_response["choices"][0]["message"]["content"] = outputText
elif (
hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None)
is not None
):
pass
else:
raise Exception()
except:
@ -751,26 +973,28 @@ def completion(
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
completion_tokens = response_metadata.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
),
)
if getattr(model_response.usage, "total_tokens", None) is None:
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
completion_tokens = response_metadata.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
),
)
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
model_response._hidden_params["region_name"] = client.meta.region_name
print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
return model_response
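For callers, the net effect of the Claude 3 branch above is that Bedrock Claude 3 models work through the usual `completion()` call. A hedged sketch (model ID taken from the cost map added later in this commit, credentials per the Bedrock docs table above):
```python
import os
from litellm import completion

# standard Bedrock credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = "us-west-2"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # pulled into the "system" param
    {"role": "user", "content": "Hey! how's it going?"},
]
response = completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=messages,
)
print(response)
```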

View file

@ -634,8 +634,43 @@ class Huggingface(BaseLLM):
status_code=r.status_code,
message=str(text),
)
"""
Check first chunk for error message.
If error message, raise error.
If not - add back to stream
"""
# Async iterator over the lines in the response body
response_iterator = r.aiter_lines()
# Attempt to get the first line/chunk from the response
try:
first_chunk = await response_iterator.__anext__()
except StopAsyncIteration:
# Handle the case where there are no lines to read (empty response)
first_chunk = ""
# Check the first chunk for an error message
if (
"error" in first_chunk.lower()
): # Adjust this condition based on how error messages are structured
raise HuggingfaceError(
status_code=400,
message=first_chunk,
)
# Create a new async generator that begins with the first_chunk and includes the remaining items
async def custom_stream_with_first_chunk():
yield first_chunk # Yield back the first chunk
async for (
chunk
) in response_iterator: # Continue yielding the rest of the chunks
yield chunk
# Creating a new completion stream that starts with the first chunk
completion_stream = custom_stream_with_first_chunk()
streamwrapper = CustomStreamWrapper(
completion_stream=r.aiter_lines(),
completion_stream=completion_stream,
model=model,
custom_llm_provider="huggingface",
logging_obj=logging_obj,
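The pattern here, peek at the first item of an async iterator and then chain it back in front of the remaining items, can be shown standalone (a generic sketch, not LiteLLM code):
```python
import asyncio

async def fake_stream():
    # stand-in for r.aiter_lines()
    for line in ["data: hello", "data: world"]:
        yield line

async def main():
    response_iterator = fake_stream()
    try:
        first_chunk = await response_iterator.__anext__()
    except StopAsyncIteration:
        first_chunk = ""  # empty response body
    if "error" in first_chunk.lower():
        raise RuntimeError(first_chunk)

    async def stream_with_first_chunk():
        yield first_chunk  # put the peeked line back in front
        async for chunk in response_iterator:
            yield chunk

    async for line in stream_with_first_chunk():
        print(line)

asyncio.run(main())
```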

View file

@ -1,8 +1,9 @@
from enum import Enum
import requests, traceback
import json
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, Environment, meta
from typing import Optional, Any
import imghdr, base64
def default_pt(messages):
@ -110,9 +111,9 @@ def mistral_instruct_pt(messages):
"post_message": " [/INST]\n",
},
"user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
"assistant": {"pre_message": " ", "post_message": " "},
"assistant": {"pre_message": " ", "post_message": "</s> "},
},
final_prompt_value="</s>",
final_prompt_value="",
messages=messages,
)
return prompt
@ -390,7 +391,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
return prompt
###
### ANTHROPIC ###
def anthropic_pt(
@ -424,6 +425,184 @@ def anthropic_pt(
return prompt
def construct_format_parameters_prompt(parameters: dict):
parameter_str = "<parameter>\n"
for k, v in parameters.items():
parameter_str += f"<{k}>"
parameter_str += f"{v}"
parameter_str += f"</{k}>"
parameter_str += "\n</parameter>"
return parameter_str
def construct_format_tool_for_claude_prompt(name, description, parameters):
constructed_prompt = (
"<tool_description>\n"
f"<tool_name>{name}</tool_name>\n"
"<description>\n"
f"{description}\n"
"</description>\n"
"<parameters>\n"
f"{construct_format_parameters_prompt(parameters)}\n"
"</parameters>\n"
"</tool_description>"
)
return constructed_prompt
def construct_tool_use_system_prompt(
tools,
): # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
tool_str_list = []
for tool in tools:
tool_str = construct_format_tool_for_claude_prompt(
tool["function"]["name"],
tool["function"].get("description", ""),
tool["function"].get("parameters", {}),
)
tool_str_list.append(tool_str)
tool_use_system_prompt = (
"In this environment you have access to a set of tools you can use to answer the user's question.\n"
"\n"
"You may call them like this:\n"
"<function_calls>\n"
"<invoke>\n"
"<tool_name>$TOOL_NAME</tool_name>\n"
"<parameters>\n"
"<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>\n"
"...\n"
"</parameters>\n"
"</invoke>\n"
"</function_calls>\n"
"\n"
"Here are the tools available:\n"
"<tools>\n" + "\n".join([tool_str for tool_str in tool_str_list]) + "\n</tools>"
)
return tool_use_system_prompt
def convert_to_anthropic_image_obj(openai_image_url: str):
"""
Input:
"image_url": "data:image/jpeg;base64,{base64_image}",
Return:
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": {base64_image},
}
"""
# Extract the base64 image data
base64_data = openai_image_url.split("data:image/")[1].split(";base64,")[1]
# Infer image format from the URL
image_format = openai_image_url.split("data:image/")[1].split(";base64,")[0]
return {
"type": "base64",
"media_type": f"image/{image_format}",
"data": base64_data,
}
def anthropic_messages_pt(messages: list):
"""
format messages for anthropic
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
2. The first message always needs to be of role "user"
3. Each message must alternate between "user" and "assistant" (this is not addressed by litellm as of now)
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
5. System messages are a separate param to the Messages API (used for tool calling)
"""
## Ensure final assistant message has no trailing whitespace
last_assistant_message_idx: Optional[int] = None
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
new_messages = []
if len(messages) == 1:
# check if the message is a user message
if messages[0]["role"] == "assistant":
new_messages.append({"role": "user", "content": ""})
# check if content is a list (vision)
if isinstance(messages[0]["content"], list): # vision input
new_content = []
for m in messages[0]["content"]:
if m.get("type", "") == "image_url":
new_content.append(
{
"type": "image",
"source": convert_to_anthropic_image_obj(
m["image_url"]["url"]
),
}
)
elif m.get("type", "") == "text":
new_content.append({"type": "text", "text": m["text"]})
new_messages.append({"role": messages[0]["role"], "content": new_content}) # type: ignore
else:
new_messages.append(messages[0])
return new_messages
for i in range(len(messages) - 1): # type: ignore
if i == 0 and messages[i]["role"] == "assistant":
new_messages.append({"role": "user", "content": ""})
if isinstance(messages[i]["content"], list): # vision input
new_content = []
for m in messages[i]["content"]:
if m.get("type", "") == "image_url":
new_content.append(
{
"type": "image",
"source": convert_to_anthropic_image_obj(
m["image_url"]["url"]
),
}
)
elif m.get("type", "") == "text":
new_content.append({"type": "text", "content": m["text"]})
new_messages.append({"role": messages[i]["role"], "content": new_content}) # type: ignore
else:
new_messages.append(messages[i])
if messages[i]["role"] == messages[i + 1]["role"]:
if messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
else:
new_messages.append({"role": "user", "content": ""})
if messages[i]["role"] == "assistant":
last_assistant_message_idx = i
if last_assistant_message_idx is not None:
new_messages[last_assistant_message_idx]["content"] = new_messages[
last_assistant_message_idx
][
"content"
].strip() # no trailing whitespace for final assistant message
return new_messages
def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str]:
ext_list = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
if strip:
ext_list = [e.strip() for e in ext_list]
return ext_list
def parse_xml_params(xml_content):
root = ET.fromstring(xml_content)
params = {}
for child in root.findall(".//parameters/*"):
params[child.tag] = child.text
return params
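A standalone illustration of the round trip these helpers implement (hypothetical tool name and model output, not from the repo): the tool-use system prompt asks Claude to reply with `<invoke>` XML, and the completion handlers recover the function name and arguments from it.
```python
import json, re
import xml.etree.ElementTree as ET

model_output = (
    "<function_calls><invoke><tool_name>get_weather</tool_name>"
    "<parameters><city>Berlin</city><unit>celsius</unit></parameters>"
    "</invoke></function_calls>"
)

# same regex shape as extract_between_tags("invoke", ...) / ("tool_name", ...)
invoke_body = re.findall(r"<invoke>(.+?)</invoke>", model_output, re.DOTALL)[0].strip()
tool_name = re.findall(r"<tool_name>(.+?)</tool_name>", model_output, re.DOTALL)[0]

# same parsing shape as parse_xml_params(...)
root = ET.fromstring(f"<invoke>{invoke_body}</invoke>")
params = {child.tag: child.text for child in root.findall(".//parameters/*")}

print(tool_name)           # get_weather
print(json.dumps(params))  # {"city": "Berlin", "unit": "celsius"}
```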
###
def amazon_titan_pt(
messages: list,
): # format - https://github.com/BerriAI/litellm/issues/1896
@ -650,10 +829,9 @@ def prompt_factory(
if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic":
if any(_ in model for _ in ["claude-2.1", "claude-v2:1"]):
return claude_2_1_pt(messages=messages)
else:
if model == "claude-instant-1" or model == "claude-2":
return anthropic_pt(messages=messages)
return anthropic_messages_pt(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(
@ -674,6 +852,8 @@ def prompt_factory(
return claude_2_1_pt(messages=messages)
else:
return anthropic_pt(messages=messages)
elif "mistral." in model:
return mistral_instruct_pt(messages=messages)
try:
if "meta-llama/llama-2" in model and "chat" in model:
return llama_2_chat_pt(messages=messages)

View file

@ -40,6 +40,7 @@ from litellm.utils import (
)
from .llms import (
anthropic,
anthropic_text,
together_ai,
ai21,
sagemaker,
@ -1019,28 +1020,55 @@ def completion(
or litellm.api_key
or os.environ.get("ANTHROPIC_API_KEY")
)
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
if (model == "claude-2") or (model == "claude-instant-1"):
# call anthropic /completion, only use this route for claude-2, claude-instant-1
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
response = anthropic_text.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
else:
# call /messages
# default route for all anthropic models
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages"
)
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
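For callers, the routing above is transparent: the legacy text-completion models keep using `/v1/complete` via `anthropic_text.completion`, while everything else, including the new Claude 3 models, goes through `/v1/messages`. A hedged usage sketch:
```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "your-api-key"
messages = [{"role": "user", "content": "Hey! how's it going?"}]

# routed to https://api.anthropic.com/v1/complete (anthropic_text.completion)
legacy = completion(model="claude-2", messages=messages)

# routed to https://api.anthropic.com/v1/messages (anthropic.completion)
claude3 = completion(model="claude-3-opus-20240229", messages=messages)
```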

View file

@ -643,6 +643,22 @@
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-3-opus-20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-3-sonnet-20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "anthropic",
"mode": "chat"
},
"text-bison": {
"max_tokens": 8192,
"input_cost_per_token": 0.000000125,
@ -1236,6 +1252,29 @@
"litellm_provider": "bedrock",
"mode": "embedding"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct": {
"max_tokens": 32000,
"input_cost_per_token": 0.00000045,
"output_cost_per_token": 0.0000007,
"litellm_provider": "bedrock",
"mode": "completion"
},
"bedrock/us-west-2/mistral.mistral-7b-instruct": {
"max_tokens": 32000,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.0000002,
"litellm_provider": "bedrock",
"mode": "completion"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 200000,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v1": {
"max_tokens": 100000,
"max_output_tokens": 8191,
@ -2220,4 +2259,4 @@
"mode": "embedding"
}
}
}
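A quick sanity check of the new Claude 3 entries above (illustrative arithmetic only, cost = tokens × cost_per_token):
```python
# Claude 3 Opus: $0.000015 / input token, $0.000075 / output token
opus_input_cost = 10_000 * 0.000015    # $0.15 for 10k prompt tokens
opus_output_cost = 1_000 * 0.000075    # $0.075 for 1k completion tokens

# Claude 3 Sonnet: $0.000003 / input token, $0.000015 / output token
sonnet_input_cost = 10_000 * 0.000003  # $0.03 for 10k prompt tokens
sonnet_output_cost = 1_000 * 0.000015  # $0.015 for 1k completion tokens

print(opus_input_cost + opus_output_cost)      # 0.225
print(sonnet_input_cost + sonnet_output_cost)  # 0.045
```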

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-59d9232c3e7a8be6.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-a85b2c176012d8e5.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e1b183dda365ec86.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>🚅 LiteLLM</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-59d9232c3e7a8be6.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/32e93a3d13512de5.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[56239,[\"730\",\"static/chunks/730-1411b729a1c79695.js\",\"931\",\"static/chunks/app/page-37bd7c3d0bb898a3.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/32e93a3d13512de5.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"p1zjZBLDqxGf-NaFvZkeF\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"🚅 LiteLLM\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-59d9232c3e7a8be6.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-a85b2c176012d8e5.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-e1b183dda365ec86.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>🚅 LiteLLM</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-59d9232c3e7a8be6.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/32e93a3d13512de5.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[57492,[\"730\",\"static/chunks/730-1411b729a1c79695.js\",\"931\",\"static/chunks/app/page-2ed0bc91ffef505b.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/32e93a3d13512de5.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"ZF-EluyKCEJoZptE3dOXT\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"🚅 LiteLLM\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""]
3:I[56239,["730","static/chunks/730-1411b729a1c79695.js","931","static/chunks/app/page-37bd7c3d0bb898a3.js"],""]
3:I[57492,["730","static/chunks/730-1411b729a1c79695.js","931","static/chunks/app/page-2ed0bc91ffef505b.js"],""]
4:I[5613,[],""]
5:I[31778,[],""]
0:["p1zjZBLDqxGf-NaFvZkeF",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/32e93a3d13512de5.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
0:["ZF-EluyKCEJoZptE3dOXT",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/32e93a3d13512de5.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"🚅 LiteLLM"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null

View file

@ -19,6 +19,8 @@ telemetry = None
def append_query_params(url, params):
print(f"url: {url}")
print(f"params: {params}")
parsed_url = urlparse.urlparse(url)
parsed_query = urlparse.parse_qs(parsed_url.query)
parsed_query.update(params)
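
For orientation, here is a minimal, self-contained sketch of what a query-append helper like this typically does end to end. The hunk above only shows the parsing and merging half, so the `urlencode`/`urlunparse` tail below is an assumption for illustration, not the project's actual implementation.

```python
from urllib import parse as urlparse  # same alias style as the snippet above


def append_query_params_sketch(url: str, params: dict) -> str:
    parsed_url = urlparse.urlparse(url)
    parsed_query = urlparse.parse_qs(parsed_url.query)  # existing query -> {key: [values]}
    parsed_query.update(params)                         # merge/override with the new params
    encoded_query = urlparse.urlencode(parsed_query, doseq=True)
    return urlparse.urlunparse(parsed_url._replace(query=encoded_query))


# append_query_params_sketch("https://example.com/cb?a=1", {"b": "2"})
# -> "https://example.com/cb?a=1&b=2"
```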

View file

@ -167,6 +167,15 @@ class ProxyException(Exception):
self.param = param
self.code = code
def to_dict(self) -> dict:
"""Converts the ProxyException instance to a dictionary."""
return {
"message": self.message,
"type": self.type,
"param": self.param,
"code": self.code,
}
@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
@ -2241,12 +2250,14 @@ async def async_data_generator(response, user_api_key_dict):
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
raise ProxyException(
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
def select_data_generator(response, user_api_key_dict):
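
Net effect of this hunk: a failure inside the stream is no longer raised out of the generator; it is serialized with `ProxyException.to_dict()` and pushed to the client as a `data:` line. A rough client-side sketch of how such a payload could be detected (the helper name, and the `[DONE]` sentinel check, are illustrative assumptions):

```python
import json


def handle_sse_line(raw_line: str) -> str | None:
    """Return the text delta from one 'data: ...' line, raising if it carries an error."""
    if not raw_line.startswith("data:"):
        return None
    payload = raw_line[len("data:"):].strip()
    if not payload or payload == "[DONE]":
        return None
    chunk = json.loads(payload)
    if "error" in chunk:
        # Shape matches ProxyException.to_dict(): message, type, param, code
        err = chunk["error"]
        raise RuntimeError(f"proxy stream error {err.get('code')}: {err.get('message')}")
    return chunk["choices"][0]["delta"].get("content", "")
```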
@ -5800,6 +5811,58 @@ async def model_info_v2(
return {"data": all_models}
@router.get(
"/model/metrics",
description="View number of requests & avg latency per model on config.yaml",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def model_metrics(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global prisma_client
if prisma_client is None:
raise ProxyException(
message="Prisma Client is not initialized",
type="internal_error",
param="None",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
sql_query = """
SELECT
CASE WHEN api_base = '' THEN model ELSE CONCAT(model, '-', api_base) END AS combined_model_api_base,
COUNT(*) AS num_requests,
AVG(EXTRACT(epoch FROM ("endTime" - "startTime"))) AS avg_latency_seconds
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= NOW() - INTERVAL '10000 hours'
GROUP BY
CASE WHEN api_base = '' THEN model ELSE CONCAT(model, '-', api_base) END
ORDER BY
num_requests DESC
LIMIT 50;
"""
db_response = await prisma_client.db.query_raw(query=sql_query)
response: List[dict] = []
if db_response is not None:
# loop through all models
for model_data in db_response:
model = model_data.get("combined_model_api_base", "")
num_requests = model_data.get("num_requests", 0)
avg_latency_seconds = model_data.get("avg_latency_seconds", 0)
response.append(
{
"model": model,
"num_requests": num_requests,
"avg_latency_seconds": avg_latency_seconds,
}
)
return response
@router.get(
"/model/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",

118
litellm/tests/log.txt Normal file
View file

@ -0,0 +1,118 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.6, pytest-7.3.1, pluggy-1.3.0
rootdir: /Users/krrishdholakia/Documents/litellm/litellm/tests
plugins: timeout-2.2.0, asyncio-0.23.2, anyio-3.7.1, xdist-3.3.1
asyncio: mode=Mode.STRICT
collected 1 item
test_custom_callback_input.py . [100%]
=============================== warnings summary ===============================
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271
/opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_config.py:271: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../proxy/_types.py:99
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:99: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
extra = Extra.allow # Allow extra fields
../proxy/_types.py:102
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:102: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:131
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:131: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:177
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:177: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:232
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:232: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:244
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:244: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:279
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:279: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:305
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:305: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_fields.py:149
/opt/homebrew/lib/python3.11/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_max_budget" has conflict with protected namespace "model_".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
warnings.warn(
../proxy/_types.py:553
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:553: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:574
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:574: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../utils.py:36
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:36: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.cloud')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(parent)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.logging')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google.iam')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871
/opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
declare_namespace(pkg)
../llms/prompt_templates/factory.py:6
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr, base64
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 1 passed, 43 warnings in 13.05s ========================

View file

@ -41,14 +41,16 @@ def test_function_call_non_openai_model():
pass
test_function_call_non_openai_model()
# test_function_call_non_openai_model()
## case 2: add_function_to_prompt set
@pytest.mark.skip(reason="Anthropic now supports tool calling")
def test_function_call_non_openai_model_litellm_mod_set():
litellm.add_function_to_prompt = True
litellm.set_verbose = True
try:
model = "claude-instant-1"
model = "claude-instant-1.2"
messages = [{"role": "user", "content": "what's the weather in sf?"}]
functions = [
{

View file

@ -1,4 +1,4 @@
## @pytest.mark.skip(reason="AWS Suspended Account")
# # @pytest.mark.skip(reason="AWS Suspended Account")
# import sys
# import os
# import io, asyncio

View file

@ -351,7 +351,7 @@ def test_gemini_pro_vision_base64():
load_vertex_ai_credentials()
litellm.set_verbose = True
litellm.num_retries = 3
image_path = "cached_logo.jpg"
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
resp = litellm.completion(

View file

@ -1,259 +1,401 @@
# @pytest.mark.skip(reason="AWS Suspended Account")
# import sys, os
# import traceback
# from dotenv import load_dotenv
import sys, os
import traceback
from dotenv import load_dotenv
# load_dotenv()
# import os, io
load_dotenv()
import os, io
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion, completion_cost, Timeout
# from litellm import RateLimitError
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout, ModelResponse
from litellm import RateLimitError
# # litellm.num_retries = 3
# litellm.cache = None
# litellm.success_callback = []
# user_message = "Write a short poem about the sky"
# messages = [{"content": user_message, "role": "user"}]
# litellm.num_retries = 3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"
messages = [{"content": user_message, "role": "user"}]
# @pytest.fixture(autouse=True)
# def reset_callbacks():
# print("\npytest fixture - resetting callbacks")
# litellm.success_callback = []
# litellm._async_success_callback = []
# litellm.failure_callback = []
# litellm.callbacks = []
@pytest.fixture(autouse=True)
def reset_callbacks():
print("\npytest fixture - resetting callbacks")
litellm.success_callback = []
litellm._async_success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
# def test_completion_bedrock_claude_completion_auth():
# print("calling bedrock claude completion params auth")
# import os
def test_completion_bedrock_claude_completion_auth():
print("calling bedrock claude completion params auth")
import os
# aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
# aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
# aws_region_name = os.environ["AWS_REGION_NAME"]
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = os.environ["AWS_REGION_NAME"]
# os.environ.pop("AWS_ACCESS_KEY_ID", None)
# os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
# os.environ.pop("AWS_REGION_NAME", None)
os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
os.environ.pop("AWS_REGION_NAME", None)
# try:
# response = completion(
# model="bedrock/anthropic.claude-instant-v1",
# messages=messages,
# max_tokens=10,
# temperature=0.1,
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_region_name=aws_region_name,
# )
# # Add any assertions here to check the response
# print(response)
try:
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
# Add any assertions here to check the response
print(response)
# os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
# os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
# os.environ["AWS_REGION_NAME"] = aws_region_name
# except RateLimitError:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_REGION_NAME"] = aws_region_name
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # test_completion_bedrock_claude_completion_auth()
# test_completion_bedrock_claude_completion_auth()
# def test_completion_bedrock_claude_2_1_completion_auth():
# print("calling bedrock claude 2.1 completion params auth")
# import os
def test_completion_bedrock_claude_2_1_completion_auth():
print("calling bedrock claude 2.1 completion params auth")
import os
# aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
# aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
# aws_region_name = os.environ["AWS_REGION_NAME"]
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = os.environ["AWS_REGION_NAME"]
# os.environ.pop("AWS_ACCESS_KEY_ID", None)
# os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
# os.environ.pop("AWS_REGION_NAME", None)
# try:
# response = completion(
# model="bedrock/anthropic.claude-v2:1",
# messages=messages,
# max_tokens=10,
# temperature=0.1,
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_region_name=aws_region_name,
# )
# # Add any assertions here to check the response
# print(response)
os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
os.environ.pop("AWS_REGION_NAME", None)
try:
response = completion(
model="bedrock/anthropic.claude-v2:1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
# Add any assertions here to check the response
print(response)
# os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
# os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
# os.environ["AWS_REGION_NAME"] = aws_region_name
# except RateLimitError:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_REGION_NAME"] = aws_region_name
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # test_completion_bedrock_claude_2_1_completion_auth()
# test_completion_bedrock_claude_2_1_completion_auth()
# def test_completion_bedrock_claude_external_client_auth():
# print("\ncalling bedrock claude external client auth")
# import os
def test_completion_bedrock_claude_external_client_auth():
print("\ncalling bedrock claude external client auth")
import os
# aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
# aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
# aws_region_name = os.environ["AWS_REGION_NAME"]
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = os.environ["AWS_REGION_NAME"]
# os.environ.pop("AWS_ACCESS_KEY_ID", None)
# os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
# os.environ.pop("AWS_REGION_NAME", None)
os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
os.environ.pop("AWS_REGION_NAME", None)
# try:
# import boto3
try:
import boto3
# litellm.set_verbose = True
litellm.set_verbose = True
# bedrock = boto3.client(
# service_name="bedrock-runtime",
# region_name=aws_region_name,
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
# )
bedrock = boto3.client(
service_name="bedrock-runtime",
region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
)
# response = completion(
# model="bedrock/anthropic.claude-instant-v1",
# messages=messages,
# max_tokens=10,
# temperature=0.1,
# aws_bedrock_client=bedrock,
# )
# # Add any assertions here to check the response
# print(response)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_bedrock_client=bedrock,
)
# Add any assertions here to check the response
print(response)
# os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
# os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
# os.environ["AWS_REGION_NAME"] = aws_region_name
# except RateLimitError:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_REGION_NAME"] = aws_region_name
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # test_completion_bedrock_claude_external_client_auth()
# test_completion_bedrock_claude_external_client_auth()
# @pytest.mark.skip(reason="Expired token, need to renew")
# def test_completion_bedrock_claude_sts_client_auth():
# print("\ncalling bedrock claude external client auth")
# import os
@pytest.mark.skip(reason="Expired token, need to renew")
def test_completion_bedrock_claude_sts_client_auth():
print("\ncalling bedrock claude external client auth")
import os
# aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"]
# aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"]
# aws_region_name = os.environ["AWS_REGION_NAME"]
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
aws_access_key_id = os.environ["AWS_TEMP_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_TEMP_SECRET_ACCESS_KEY"]
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# try:
# import boto3
try:
import boto3
# litellm.set_verbose = True
litellm.set_verbose = True
# response = completion(
# model="bedrock/anthropic.claude-instant-v1",
# messages=messages,
# max_tokens=10,
# temperature=0.1,
# aws_region_name=aws_region_name,
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_role_name=aws_role_name,
# aws_session_name="my-test-session",
# )
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# response = embedding(
# model="cohere.embed-multilingual-v3",
# input=["hello world"],
# aws_region_name="us-east-1",
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_role_name=aws_role_name,
# aws_session_name="my-test-session",
# )
response = embedding(
model="cohere.embed-multilingual-v3",
input=["hello world"],
aws_region_name="us-east-1",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# response = completion(
# model="gpt-3.5-turbo",
# messages=messages,
# aws_region_name="us-east-1",
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_role_name=aws_role_name,
# aws_session_name="my-test-session",
# )
# # Add any assertions here to check the response
# print(response)
# except RateLimitError:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
response = completion(
model="gpt-3.5-turbo",
messages=messages,
aws_region_name="us-east-1",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # test_completion_bedrock_claude_sts_client_auth()
# test_completion_bedrock_claude_sts_client_auth()
# def test_provisioned_throughput():
# try:
# litellm.set_verbose = True
# import botocore, json, io
# import botocore.session
# from botocore.stub import Stubber
# bedrock_client = botocore.session.get_session().create_client(
# "bedrock-runtime", region_name="us-east-1"
# )
# expected_params = {
# "accept": "application/json",
# "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
# '"max_tokens_to_sample": 256}',
# "contentType": "application/json",
# "modelId": "provisioned-model-arn",
# }
# response_from_bedrock = {
# "body": io.StringIO(
# json.dumps(
# {
# "completion": " Here is a short poem about the sky:",
# "stop_reason": "max_tokens",
# "stop": None,
# }
# )
# ),
# "contentType": "contentType",
# "ResponseMetadata": {"HTTPStatusCode": 200},
# }
# with Stubber(bedrock_client) as stubber:
# stubber.add_response(
# "invoke_model",
# service_response=response_from_bedrock,
# expected_params=expected_params,
# )
# response = litellm.completion(
# model="bedrock/anthropic.claude-instant-v1",
# model_id="provisioned-model-arn",
# messages=[{"content": "Hello, how are you?", "role": "user"}],
# aws_bedrock_client=bedrock_client,
# )
# print("response stubbed", response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
try:
litellm.set_verbose = True
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
max_tokens=10,
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # test_provisioned_throughput()
def test_bedrock_claude_3_tool_calling():
try:
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
messages = [
{"role": "user", "content": "What's the weather like in Boston today?"}
]
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
tools=tools,
tool_choice="auto",
)
print(f"response: {response}")
# Add any assertions here to check the response
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def encode_image(image_path):
import base64
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@pytest.mark.skip(
reason="we already test claude-3, this is just another way to pass images"
)
def test_completion_claude_3_base64():
try:
litellm.set_verbose = True
litellm.num_retries = 3
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
resp = litellm.completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64," + base64_image
},
},
],
}
],
)
prompt_tokens = resp.usage.prompt_tokens
raise Exception("it worked!")
except Exception as e:
if "500 Internal error encountered.'" in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
def test_provisioned_throughput():
try:
litellm.set_verbose = True
import botocore, json, io
import botocore.session
from botocore.stub import Stubber
bedrock_client = botocore.session.get_session().create_client(
"bedrock-runtime", region_name="us-east-1"
)
expected_params = {
"accept": "application/json",
"body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
'"max_tokens_to_sample": 256}',
"contentType": "application/json",
"modelId": "provisioned-model-arn",
}
response_from_bedrock = {
"body": io.StringIO(
json.dumps(
{
"completion": " Here is a short poem about the sky:",
"stop_reason": "max_tokens",
"stop": None,
}
)
),
"contentType": "contentType",
"ResponseMetadata": {"HTTPStatusCode": 200},
}
with Stubber(bedrock_client) as stubber:
stubber.add_response(
"invoke_model",
service_response=response_from_bedrock,
expected_params=expected_params,
)
response = litellm.completion(
model="bedrock/anthropic.claude-instant-v1",
model_id="provisioned-model-arn",
messages=[{"content": "Hello, how are you?", "role": "user"}],
aws_bedrock_client=bedrock_client,
)
print("response stubbed", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_provisioned_throughput()
def test_completion_bedrock_mistral_completion_auth():
print("calling bedrock mistral completion params auth")
import os
# aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
# aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
# aws_region_name = os.environ["AWS_REGION_NAME"]
# os.environ.pop("AWS_ACCESS_KEY_ID", None)
# os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
# os.environ.pop("AWS_REGION_NAME", None)
try:
response: ModelResponse = completion(
model="bedrock/mistral.mistral-7b-instruct-v0:2",
messages=messages,
max_tokens=10,
temperature=0.1,
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
# os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
# os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
# os.environ["AWS_REGION_NAME"] = aws_region_name
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_mistral_completion_auth()

View file

@ -546,7 +546,6 @@ def test_redis_cache_acompletion_stream():
# test_redis_cache_acompletion_stream()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_redis_cache_acompletion_stream_bedrock():
import asyncio

View file

@ -56,7 +56,7 @@ def test_completion_custom_provider_model_name():
def test_completion_claude():
litellm.set_verbose = True
litellm.cache = None
litellm.AnthropicConfig(max_tokens_to_sample=200, metadata={"user_id": "1224"})
litellm.AnthropicTextConfig(max_tokens_to_sample=200, metadata={"user_id": "1224"})
messages = [
{
"role": "system",
@ -67,9 +67,7 @@ def test_completion_claude():
try:
# test without max tokens
response = completion(
model="claude-instant-1",
messages=messages,
request_timeout=10,
model="claude-instant-1", messages=messages, request_timeout=10
)
# Add any assertions, here to check response args
print(response)
@ -84,6 +82,126 @@ def test_completion_claude():
# test_completion_claude()
def test_completion_claude_3():
litellm.set_verbose = True
messages = [{"role": "user", "content": "Hello, world"}]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
)
# Add any assertions, here to check response args
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
assert isinstance(
response.choices[0].message.tool_calls[0].function.arguments, str
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_stream():
litellm.set_verbose = False
messages = [{"role": "user", "content": "Hello, world"}]
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
messages=messages,
max_tokens=10,
stream=True,
)
# Add any assertions, here to check response args
print(response)
for chunk in response:
print(chunk)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def encode_image(image_path):
import base64
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@pytest.mark.skip(
reason="we already test claude-3, this is just another way to pass images"
)
def test_completion_claude_3_base64():
try:
litellm.set_verbose = True
litellm.num_retries = 3
image_path = "../proxy/cached_logo.jpg"
# Getting the base64 string
base64_image = encode_image(image_path)
resp = litellm.completion(
model="anthropic/claude-3-opus-20240229",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64," + base64_image
},
},
],
}
],
)
print(f"\nResponse: {resp}")
prompt_tokens = resp.usage.prompt_tokens
raise Exception("it worked!")
except Exception as e:
if "500 Internal error encountered.'" in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
def test_completion_mistral_api():
try:
litellm.set_verbose = True
@ -163,19 +281,17 @@ def test_completion_mistral_api_modified_input():
def test_completion_claude2_1():
try:
litellm.set_verbose = True
print("claude2.1 test request")
messages = [
{
"role": "system",
"content": "Your goal is generate a joke on the topic user gives",
"content": "Your goal is generate a joke on the topic user gives.",
},
{"role": "assistant", "content": "Hi, how can i assist you today?"},
{"role": "user", "content": "Generate a 3 liner joke for me"},
]
# test without max tokens
response = completion(
model="claude-2.1", messages=messages, request_timeout=10, max_tokens=10
)
response = completion(model="claude-2.1", messages=messages)
# Add any assertions here to check the response
print(response)
print(response.usage)
@ -1530,7 +1646,6 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_titan_null_response():
try:
response = completion(
@ -1556,7 +1671,6 @@ def test_completion_bedrock_titan_null_response():
pytest.fail(f"An error occurred - {str(e)}")
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_titan():
try:
response = completion(
@ -1578,7 +1692,6 @@ def test_completion_bedrock_titan():
# test_completion_bedrock_titan()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_claude():
print("calling claude")
try:
@ -1600,7 +1713,6 @@ def test_completion_bedrock_claude():
# test_completion_bedrock_claude()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_cohere():
print("calling bedrock cohere")
litellm.set_verbose = True

View file

@ -171,7 +171,6 @@ def test_cost_openai_image_gen():
assert cost == 0.019922944
@pytest.mark.skip(reason="AWS Suspended Account")
def test_cost_bedrock_pricing():
"""
- get pricing specific to region for a model

View file

@ -115,4 +115,13 @@ model_list:
model_info:
description: this is a test openai model
id: 34cb2419-7c63-44ae-a189-53f1d1ce5953
model_name: test_openai_models
model_name: test_openai_models
- litellm_params:
model: amazon.titan-embed-text-v1
model_name: amazon-embeddings
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 753dca9a-898d-4ff7-9961-5acf7cdf38cf
model_name: test_openai_models

View file

@ -478,7 +478,6 @@ async def test_async_chat_azure_stream():
## Test Bedrock + sync
@pytest.mark.skip(reason="AWS Suspended Account")
def test_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
@ -519,7 +518,6 @@ def test_chat_bedrock_stream():
## Test Bedrock + Async
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
try:
@ -796,7 +794,6 @@ async def test_async_embedding_azure():
## Test Bedrock + Async
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
try:

View file

@ -256,7 +256,6 @@ async def test_vertexai_aembedding():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="AWS Suspended Account")
def test_bedrock_embedding_titan():
try:
# this tests if we support str input for bedrock embedding
@ -302,7 +301,6 @@ def test_bedrock_embedding_titan():
# test_bedrock_embedding_titan()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_bedrock_embedding_cohere():
try:
litellm.set_verbose = False

View file

@ -70,7 +70,7 @@ models = ["command-nightly"]
@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
ctx_window_fallback_dict = {
"command-nightly": "claude-2",
"command-nightly": "claude-2.1",
"gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k",
"azure/chatgpt-v-2": "gpt-3.5-turbo-16k",
}

View file

@ -121,7 +121,6 @@ async def test_async_image_generation_azure():
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.skip(reason="AWS Suspended Account")
def test_image_generation_bedrock():
try:
litellm.set_verbose = True
@ -142,7 +141,6 @@ def test_image_generation_bedrock():
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
try:

View file

@ -53,7 +53,7 @@ def claude_test_completion():
try:
# OVERRIDE WITH DYNAMIC MAX TOKENS
response_1 = litellm.completion(
model="claude-instant-1",
model="claude-instant-1.2",
messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10,
)
@ -63,7 +63,7 @@ def claude_test_completion():
# USE CONFIG TOKENS
response_2 = litellm.completion(
model="claude-instant-1",
model="claude-instant-1.2",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
@ -74,7 +74,7 @@ def claude_test_completion():
try:
response_3 = litellm.completion(
model="claude-instant-1",
model="claude-instant-1.2",
messages=[{"content": "Hello, how are you?", "role": "user"}],
n=2,
)
@ -515,7 +515,6 @@ def sagemaker_test_completion():
# Bedrock
@pytest.mark.skip(reason="AWS Suspended Account")
def bedrock_test_completion():
litellm.AmazonCohereConfig(max_tokens=10)
# litellm.set_verbose=True

View file

@ -125,7 +125,6 @@ def test_embedding(client_no_auth):
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
@pytest.mark.skip(reason="AWS Suspended Account")
def test_bedrock_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth

View file

@ -575,7 +575,6 @@ def test_azure_embedding_on_router():
# test_azure_embedding_on_router()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_bedrock_on_router():
litellm.set_verbose = True
print("\n Testing bedrock on router\n")
@ -933,7 +932,7 @@ def test_router_anthropic_key_dynamic():
{
"model_name": "anthropic-claude",
"litellm_params": {
"model": "claude-instant-1",
"model": "claude-instant-1.2",
"api_key": anthropic_api_key,
},
}

View file

@ -35,7 +35,7 @@ def test_router_timeouts():
{
"model_name": "anthropic-claude-instant-1.2",
"litellm_params": {
"model": "claude-instant-1",
"model": "claude-instant-1.2",
"api_key": "os.environ/ANTHROPIC_API_KEY",
},
"tpm": 20000,
@ -87,7 +87,6 @@ def test_router_timeouts():
print("********** TOKENS USED SO FAR = ", total_tokens_used)
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_router_timeouts_bedrock():
import openai

View file

@ -348,7 +348,7 @@ def test_completion_claude_stream():
},
]
response = completion(
model="claude-instant-1", messages=messages, stream=True, max_tokens=50
model="claude-instant-1.2", messages=messages, stream=True, max_tokens=50
)
complete_response = ""
# Add any assertions here to check the response
@ -727,6 +727,7 @@ def test_completion_claude_stream_bad_key():
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="Replicate changed exceptions")
def test_completion_replicate_stream_bad_key():
try:
api_key = "bad-key"
@ -764,7 +765,6 @@ def test_completion_replicate_stream_bad_key():
# test_completion_replicate_stream_bad_key()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_claude_stream():
try:
litellm.set_verbose = False
@ -811,7 +811,6 @@ def test_completion_bedrock_claude_stream():
# test_completion_bedrock_claude_stream()
@pytest.mark.skip(reason="AWS Suspended Account")
def test_completion_bedrock_ai21_stream():
try:
litellm.set_verbose = False
@ -1060,6 +1059,7 @@ def ai21_completion_call_bad_key():
# ai21_completion_call_bad_key()
@pytest.mark.skip(reason="flaky test")
@pytest.mark.asyncio
async def test_hf_completion_tgi_stream():
try:

View file

@ -2836,6 +2836,8 @@ def test_completion_hf_prompt_array():
print(str(e))
if "is currently loading" in str(e):
return
if "Service Unavailable" in str(e):
return
pytest.fail(f"Error occurred: {e}")

View file

@ -200,6 +200,10 @@ def map_finish_reason(
return "content_filter"
elif finish_reason == "STOP": # vertex ai
return "stop"
elif finish_reason == "end_turn" or finish_reason == "stop_sequence": # anthropic
return "stop"
elif finish_reason == "max_tokens": # anthropic
return "length"
return finish_reason
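
The added branches, expressed as plain assertions (assuming the function is imported from the module this hunk edits):

```python
from litellm.utils import map_finish_reason

assert map_finish_reason("end_turn") == "stop"       # Anthropic: normal end of turn
assert map_finish_reason("stop_sequence") == "stop"  # Anthropic: a stop sequence was hit
assert map_finish_reason("max_tokens") == "length"   # Anthropic: output token limit reached
```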
@ -241,10 +245,12 @@ class Message(OpenAIObject):
self.role = role
if function_call is not None:
self.function_call = FunctionCall(**function_call)
if tool_calls is not None:
self.tool_calls = []
for tool_call in tool_calls:
self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
if logprobs is not None:
self._logprobs = logprobs
@ -2585,7 +2591,7 @@ def client(original_function):
if (
isinstance(e, openai.APIError)
or isinstance(e, openai.Timeout)
or isinstance(openai.APIConnectionError)
or isinstance(e, openai.APIConnectionError)
):
print_verbose(f"RETRY TRIGGERED!")
kwargs["num_retries"] = num_retries
@ -4106,6 +4112,8 @@ def get_optional_params(
and custom_llm_provider != "anyscale"
and custom_llm_provider != "together_ai"
and custom_llm_provider != "mistral"
and custom_llm_provider != "anthropic"
and custom_llm_provider != "bedrock"
):
if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
# ollama actually supports json output
@ -4186,7 +4194,15 @@ def get_optional_params(
## raise exception if provider doesn't support passed in param
if custom_llm_provider == "anthropic":
## check if unsupported param passed in
supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]
supported_params = [
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
]
_check_valid_arg(supported_params=supported_params)
# handle anthropic params
if stream:
@ -4200,7 +4216,14 @@ def get_optional_params(
if top_p is not None:
optional_params["top_p"] = top_p
if max_tokens is not None:
optional_params["max_tokens_to_sample"] = max_tokens
if (model == "claude-2") or (model == "claude-instant-1"):
# these models use anthropic_text.py which only accepts max_tokens_to_sample
optional_params["max_tokens_to_sample"] = max_tokens
else:
optional_params["max_tokens"] = max_tokens
optional_params["max_tokens"] = max_tokens
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "cohere":
## check if unsupported param passed in
supported_params = [
@ -4498,20 +4521,24 @@ def get_optional_params(
if stream:
optional_params["stream"] = stream
elif "anthropic" in model:
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
supported_params = get_mapped_model_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if max_tokens is not None:
optional_params["max_tokens_to_sample"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stop is not None:
optional_params["stop_sequences"] = stop
if stream:
optional_params["stream"] = stream
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "amazon" in model: # amazon titan llms
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
@ -4551,6 +4578,21 @@ def get_optional_params(
optional_params["temperature"] = temperature
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
elif "mistral" in model:
supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
_check_valid_arg(supported_params=supported_params)
# mistral params on bedrock
# \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if stop is not None:
optional_params["stop"] = stop
if stream is not None:
optional_params["stream"] = stream
elif custom_llm_provider == "aleph_alpha":
supported_params = [
"max_tokens",
@ -4961,6 +5003,17 @@ def get_optional_params(
return optional_params
def get_mapped_model_params(model: str, custom_llm_provider: str):
"""
Returns the supported openai params for a given model + provider
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
else:
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
def get_llm_provider(
model: str,
custom_llm_provider: Optional[str] = None,
@ -8017,10 +8070,21 @@ class CustomStreamWrapper:
finish_reason = None
if str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
text = data_json.get("completion", "")
if data_json.get("stop_reason", None):
type_chunk = data_json.get("type", None)
if type_chunk == "content_block_delta":
"""
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
text = data_json.get("delta", {}).get("text", "")
elif type_chunk == "message_delta":
"""
Anthropic
chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
"""
# TODO - get usage from this chunk, set in response
finish_reason = data_json.get("delta", {}).get("stop_reason", None)
is_finished = True
finish_reason = data_json["stop_reason"]
return {
"text": text,
"is_finished": is_finished,
@ -8091,7 +8155,8 @@ class CustomStreamWrapper:
text = "" # don't return the final bos token
is_finished = True
finish_reason = "stop"
elif data_json.get("error", False):
raise Exception(data_json.get("error"))
return {
"text": text,
"is_finished": is_finished,
@ -8106,7 +8171,7 @@ class CustomStreamWrapper:
}
except Exception as e:
traceback.print_exc()
# raise(e)
raise e
def handle_ai21_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8")

View file

@ -643,6 +643,22 @@
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-3-opus-20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000015,
"output_cost_per_token": 0.000075,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-3-sonnet-20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "anthropic",
"mode": "chat"
},
"text-bison": {
"max_tokens": 8192,
"input_cost_per_token": 0.000000125,
@ -1236,6 +1252,29 @@
"litellm_provider": "bedrock",
"mode": "embedding"
},
"bedrock/us-west-2/mistral.mixtral-8x7b-instruct": {
"max_tokens": 32000,
"input_cost_per_token": 0.00000045,
"output_cost_per_token": 0.0000007,
"litellm_provider": "bedrock",
"mode": "completion"
},
"bedrock/us-west-2/mistral.mistral-7b-instruct": {
"max_tokens": 32000,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.0000002,
"litellm_provider": "bedrock",
"mode": "completion"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 200000,
"max_input_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v1": {
"max_tokens": 100000,
"max_output_tokens": 8191,
@ -2220,4 +2259,4 @@
"mode": "embedding"
}
}
}
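
To make the new Claude 3 prices concrete, a back-of-the-envelope cost calculation; the token counts are hypothetical and this deliberately avoids litellm's own cost helpers:

```python
# claude-3-opus-20240229, taken from the JSON above
input_cost_per_token = 0.000015    # $ per prompt token
output_cost_per_token = 0.000075   # $ per completion token

prompt_tokens, completion_tokens = 1_000, 500  # hypothetical request
cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${cost:.4f}")  # $0.0525
```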

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.28.11"
version = "1.29.3"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.28.11"
version = "1.29.3"
version_files = [
"pyproject.toml:^version"
]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"🚅 LiteLLM\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@@ -1,7 +1,7 @@
2:I[77831,[],""]
3:I[56239,["730","static/chunks/730-1411b729a1c79695.js","931","static/chunks/app/page-37bd7c3d0bb898a3.js"],""]
3:I[57492,["730","static/chunks/730-1411b729a1c79695.js","931","static/chunks/app/page-2ed0bc91ffef505b.js"],""]
4:I[5613,[],""]
5:I[31778,[],""]
0:["p1zjZBLDqxGf-NaFvZkeF",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/32e93a3d13512de5.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
0:["ZF-EluyKCEJoZptE3dOXT",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/32e93a3d13512de5.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"🚅 LiteLLM"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null

View file

@@ -11,7 +11,8 @@ import {
Metric,
Grid,
} from "@tremor/react";
import { modelInfoCall, userGetRequesedtModelsCall } from "./networking";
import { modelInfoCall, userGetRequesedtModelsCall, modelMetricsCall } from "./networking";
import { BarChart } from "@tremor/react";
import { Badge, BadgeDelta, Button } from "@tremor/react";
import RequestAccess from "./request_model_access";
import { Typography } from "antd";
@@ -30,6 +31,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
userID,
}) => {
const [modelData, setModelData] = useState<any>({ data: [] });
const [modelMetrics, setModelMetrics] = useState<any[]>([]);
const [pendingRequests, setPendingRequests] = useState<any[]>([]);
useEffect(() => {
@@ -47,6 +49,15 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
console.log("Model data response:", modelDataResponse.data);
setModelData(modelDataResponse);
const modelMetricsResponse = await modelMetricsCall(
accessToken,
userID,
userRole
);
console.log("Model metrics response:", modelMetricsResponse);
setModelMetrics(modelMetricsResponse);
// if userRole is Admin, show the pending requests
if (userRole === "Admin" && accessToken) {
const user_requests = await userGetRequesedtModelsCall(accessToken);
@@ -75,8 +86,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
// loop through model data and edit each row
for (let i = 0; i < modelData.data.length; i++) {
let curr_model = modelData.data[i];
let litellm_model_name = curr_model?.litellm_params?.model;
let litellm_model_name = curr_model?.litellm_params?.model;
let model_info = curr_model?.model_info;
let defaultProvider = "openai";
@@ -109,6 +119,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
modelData.data[i].input_cost = input_cost;
modelData.data[i].output_cost = output_cost;
modelData.data[i].max_tokens = max_tokens;
modelData.data[i].api_base = curr_model?.litellm_params?.api_base;
all_models_on_proxy.push(curr_model.model_name);
@@ -141,6 +152,14 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
<TableCell>
<Title>Provider</Title>
</TableCell>
{
userRole === "Admin" && (
<TableCell>
<Title>API Base</Title>
</TableCell>
)
}
<TableCell>
<Title>Access</Title>
</TableCell>
@@ -162,6 +181,11 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
<Title>{model.model_name}</Title>
</TableCell>
<TableCell>{model.provider}</TableCell>
{
userRole === "Admin" && (
<TableCell>{model.api_base}</TableCell>
)
}
<TableCell>
{model.user_access ? (
@@ -183,7 +207,18 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</TableBody>
</Table>
</Card>
{userRole === "Admin" &&
<Card>
<Title>Model Statistics (Number of Requests, Latency)</Title>
<BarChart
data={modelMetrics}
index="model"
categories={["num_requests", "avg_latency_seconds"]}
colors={["blue", "red"]}
yAxisWidth={100}
tickGap={5}
/>
</Card>
{/* {userRole === "Admin" &&
pendingRequests &&
pendingRequests.length > 0 ? (
<Card>
@@ -229,7 +264,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</TableBody>
</Table>
</Card>
) : null}
) : null} */}
</Grid>
</div>
);

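The new statistics card charts `modelMetrics` with `index="model"` and the categories `num_requests` and `avg_latency_seconds`, which implies each row returned by the new `/model/metrics` endpoint carries those three fields. A minimal sketch of that assumed shape (the interface name and example values below are illustrative, not part of this commit):

```typescript
// Assumed shape of one /model/metrics row, inferred from the BarChart props above.
interface ModelMetricRow {
  model: string;                // used as the chart index
  num_requests: number;         // first bar category
  avg_latency_seconds: number;  // second bar category
}

// Example payload the chart could render, assuming the endpoint returns an array of rows.
const exampleMetrics: ModelMetricRow[] = [
  { model: "claude-3-opus-20240229", num_requests: 120, avg_latency_seconds: 2.4 },
  { model: "gpt-3.5-turbo", num_requests: 340, avg_latency_seconds: 0.9 },
];
```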
View file

@@ -242,6 +242,41 @@ export const modelInfoCall = async (
}
};
export const modelMetricsCall = async (
accessToken: String,
userID: String,
userRole: String
) => {
/**
* Get per-model metrics (request counts, average latency) from the proxy
*/
try {
let url = proxyBaseUrl ? `${proxyBaseUrl}/model/metrics` : `/model/metrics`;
// message.info("Requesting model data");
const response = await fetch(url, {
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorData = await response.text();
message.error(errorData);
throw new Error("Network response was not ok");
}
const data = await response.json();
// message.info("Received model data");
return data;
// Success: callers typically store these rows in component state for charting
} catch (error) {
console.error("Failed to create key:", error);
throw error;
}
};
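For reference, a minimal usage sketch of the new helper outside the dashboard component; the function name, token, and user values below are placeholders, and the expected row fields are assumed from the BarChart categories in `model_dashboard.tsx`:

```typescript
// Hypothetical caller of modelMetricsCall; all argument values are placeholders.
import { modelMetricsCall } from "./networking";

async function logModelMetrics(accessToken: string, userID: string, userRole: string) {
  try {
    const rows = await modelMetricsCall(accessToken, userID, userRole);
    // Each row is expected to carry model, num_requests, and avg_latency_seconds.
    console.table(rows);
  } catch (err) {
    console.error("Could not load /model/metrics:", err);
  }
}
```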
export const modelAvailableCall = async (
accessToken: String,
userID: String,