LiteLLM Minor fixes + improvements (08/03/2024) (#5488)

* fix(internal_user_endpoints.py): set budget_reset_at for /user/update

* fix(vertex_and_google_ai_studio_gemini.py): handle accumulated json

Fixes https://github.com/BerriAI/litellm/issues/5479

* fix(vertex_ai_and_gemini.py): fix assistant message function call when content is not None

Fixes https://github.com/BerriAI/litellm/issues/5490

* fix(proxy_server.py): generic state uuid for okta sso

* fix(lago.py): improve debug logs

Debugging for https://github.com/BerriAI/litellm/issues/5477

* docs(bedrock.md): add bedrock cross-region inferencing to docs

* fix(azure.py): return azure response headers on aembedding call

* feat(azure.py): return azure response headers for `/audio/transcription`

* fix(types/utils.py): standardize deepseek / anthropic prompt caching usage information

Closes https://github.com/BerriAI/litellm/issues/5285

* docs(usage.md): add docs on litellm usage object

* test(test_completion.py): mark flaky test
Krish Dholakia 2024-09-03 21:21:34 -07:00 committed by GitHub
parent 59042511c9
commit be3c7b401e
19 changed files with 736 additions and 81 deletions


@ -0,0 +1,175 @@
# Usage
LiteLLM returns an OpenAI-compatible usage object across all providers.
```bash
"usage": {
"prompt_tokens": int,
"completion_tokens": int,
"total_tokens": int
}
```
## Quick Start
```python
from litellm import completion
import os
## set ENV variables
os.environ["OPENAI_API_KEY"] = "your-api-key"
response = completion(
model="gpt-3.5-turbo",
messages=[{ "content": "Hello, how are you?","role": "user"}]
)
print(response.usage)
```
## Streaming Usage
If `stream_options={"include_usage": True}` is set, an additional chunk is streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and its `choices` field is always an empty array. All other chunks also include a `usage` field, but with a null value.
```python
from litellm import completion

response = completion(
    model="gpt-4o",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"}
    ],
    stream=True,
    stream_options={"include_usage": True}
)

for chunk in response:
    if chunk.choices:  # empty on the final usage-only chunk
        print(chunk.choices[0].delta)
    if getattr(chunk, "usage", None) is not None:  # populated only on the final chunk
        print(chunk.usage)
```
## Prompt Caching
For Anthropic + Deepseek, LiteLLM follows the Anthropic prompt caching usage object format:
```bash
"usage": {
"prompt_tokens": int,
"completion_tokens": int,
"total_tokens": int,
"_cache_creation_input_tokens": int, # hidden param for prompt caching. Might change, once openai introduces their equivalent.
"_cache_read_input_tokens": int # hidden param for prompt caching. Might change, once openai introduces their equivalent.
}
```
- `prompt_tokens`: These are the non-cached prompt tokens (same as Anthropic, equivalent to Deepseek `prompt_cache_miss_tokens`).
- `completion_tokens`: These are the output tokens generated by the model.
- `total_tokens`: Sum of prompt_tokens + completion_tokens.
- `_cache_creation_input_tokens`: Input tokens that were written to cache. (Anthropic only).
- `_cache_read_input_tokens`: Input tokens that were read from cache for that call. (equivalent to Deepseek `prompt_cache_hit_tokens`).
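Once a response comes back, these values can be read directly off the returned usage object. A minimal access sketch (assuming `response` was produced by one of the calls below, against a caching-enabled model):
```python
usage = response.usage

print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)

# hidden params - only populated when the provider returns prompt-caching info
print(getattr(usage, "_cache_creation_input_tokens", None))
print(getattr(usage, "_cache_read_input_tokens", None))
```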
### Anthropic Example
```python
from litellm import completion
import litellm
import os
litellm.set_verbose = True # 👈 SEE RAW REQUEST
os.environ["ANTHROPIC_API_KEY"] = ""
response = completion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant tasked with analyzing legal documents.",
},
{
"type": "text",
"text": "Here is the full text of a complex legal agreement" * 400,
"cache_control": {"type": "ephemeral"},
},
],
},
{
"role": "user",
"content": "what are the key terms and conditions in this agreement?",
},
]
)
print(response.usage)
```
### Deepseek Example
```python
from litellm import completion
import litellm
import os
os.environ["DEEPSEEK_API_KEY"] = ""
litellm.set_verbose = True # 👈 SEE RAW REQUEST
model_name = "deepseek/deepseek-chat"
messages_1 = [
{
"role": "system",
"content": "You are a history expert. The user will provide a series of questions, and your answers should be concise and start with `Answer:`",
},
{
"role": "user",
"content": "In what year did Qin Shi Huang unify the six states?",
},
{"role": "assistant", "content": "Answer: 221 BC"},
{"role": "user", "content": "Who was the founder of the Han Dynasty?"},
{"role": "assistant", "content": "Answer: Liu Bang"},
{"role": "user", "content": "Who was the last emperor of the Tang Dynasty?"},
{"role": "assistant", "content": "Answer: Li Zhu"},
{
"role": "user",
"content": "Who was the founding emperor of the Ming Dynasty?",
},
{"role": "assistant", "content": "Answer: Zhu Yuanzhang"},
{
"role": "user",
"content": "Who was the founding emperor of the Qing Dynasty?",
},
]
messages_2 = [
{
"role": "system",
"content": "You are a history expert. The user will provide a series of questions, and your answers should be concise and start with `Answer:`",
},
{
"role": "user",
"content": "In what year did Qin Shi Huang unify the six states?",
},
{"role": "assistant", "content": "Answer: 221 BC"},
{"role": "user", "content": "Who was the founder of the Han Dynasty?"},
{"role": "assistant", "content": "Answer: Liu Bang"},
{"role": "user", "content": "Who was the last emperor of the Tang Dynasty?"},
{"role": "assistant", "content": "Answer: Li Zhu"},
{
"role": "user",
"content": "Who was the founding emperor of the Ming Dynasty?",
},
{"role": "assistant", "content": "Answer: Zhu Yuanzhang"},
{"role": "user", "content": "When did the Shang Dynasty fall?"},
]
response_1 = litellm.completion(model=model_name, messages=messages_1)
response_2 = litellm.completion(model=model_name, messages=messages_2)
# the second call shares a prefix with the first, so it can be served from Deepseek's prompt cache
print(response_2.usage)
```
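The checks below are a minimal sketch of how the Deepseek-specific fields map onto the standardized usage object described above (mirroring the test assertions added in this change; they only hold when the second call actually hits the cache):
```python
usage = response_2.usage

# Deepseek reports cache hits/misses directly; LiteLLM standardizes them
assert usage.prompt_tokens == usage.prompt_cache_miss_tokens
assert usage._cache_read_input_tokens == usage.prompt_cache_hit_tokens

print(usage)
```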


@ -577,6 +577,135 @@ for chunk in response:
}
```
## Cross-region inferencing
LiteLLM supports Bedrock [cross-region inferencing](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html) across all [supported bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html).
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import litellm
import os

os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

litellm.set_verbose = True # 👈 SEE RAW REQUEST

messages = [{"role": "user", "content": "What is the capital of France?"}]

response = completion(
    model="bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
    messages=messages,
    max_tokens=10,
    temperature=0.1,
)
print("Final Response: {}".format(response))
```
</TabItem>
<TabItem value="proxy" label="PROXY">
#### 1. Setup config.yaml
```yaml
model_list:
- model_name: bedrock-claude-haiku
litellm_params:
model: bedrock/us.anthropic.claude-3-haiku-20240307-v1:0
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
```
#### 2. Start the proxy
```bash
litellm --config /path/to/config.yaml
```
#### 3. Test it
<Tabs>
<TabItem value="Curl" label="Curl Request">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "bedrock-claude-haiku",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="bedrock-claude-haiku", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy
model = "bedrock-claude-haiku",
temperature=0.1
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Alternate user/assistant messages
Use `user_continue_message` to add a default user message, for cases (e.g. Autogen) where the client might not follow the required pattern of alternating user/assistant messages that starts and ends with a user message. A minimal SDK sketch is shown below.
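This is a sketch only, assuming `user_continue_message` is accepted as a per-request parameter on `completion()`; the message content shown is illustrative:
```python
from litellm import completion
import os

os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# Autogen-style history that ends on an assistant turn
messages = [
    {"role": "user", "content": "Summarize this thread."},
    {"role": "assistant", "content": "Sure, here is a summary..."},
]

response = completion(
    model="bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
    messages=messages,
    # assumption: default user message appended when the conversation
    # does not end with a user turn
    user_continue_message={"role": "user", "content": "Please continue"},
)
print(response)
```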


@ -100,6 +100,7 @@ GENERIC_CLIENT_SECRET = "<your-okta-client-secret>"
GENERIC_AUTHORIZATION_ENDPOINT = "<your-okta-domain>/authorize" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/authorize
GENERIC_TOKEN_ENDPOINT = "<your-okta-domain>/token" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/oauth/token
GENERIC_USERINFO_ENDPOINT = "<your-okta-domain>/userinfo" # https://dev-2kqkcd6lx6kdkuzt.us.auth0.com/userinfo
GENERIC_CLIENT_STATE = "random-string" # [OPTIONAL] Required by Okta; if not set, a random state value is generated
```
You can get your domain-specific auth/token/userinfo endpoints at `<YOUR-OKTA-DOMAIN>/.well-known/openid-configuration`
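For reference, a small stand-alone discovery sketch using the standard OIDC metadata document (the domain below is the example from the comments above):
```python
import json
import urllib.request

OKTA_DOMAIN = "https://dev-2kqkcd6lx6kdkuzt.us.auth0.com"  # replace with your domain

with urllib.request.urlopen(f"{OKTA_DOMAIN}/.well-known/openid-configuration") as resp:
    config = json.load(resp)

# these map onto GENERIC_AUTHORIZATION_ENDPOINT / GENERIC_TOKEN_ENDPOINT / GENERIC_USERINFO_ENDPOINT
print(config["authorization_endpoint"])
print(config["token_endpoint"])
print(config["userinfo_endpoint"])
```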


@ -185,6 +185,7 @@ const sidebars = {
"completion/drop_params",
"completion/prompt_formatting",
"completion/output",
"completion/usage",
"exception_mapping",
"completion/stream",
"completion/message_trimming",


@ -105,9 +105,13 @@ class LagoLogger(CustomLogger):
external_customer_id = user_id
if external_customer_id is None:
raise Exception("External Customer ID is not set")
raise Exception(
"External Customer ID is not set. Charge_by={}. User_id={}. End_user_id={}. Team_id={}".format(
charge_by, user_id, end_user_id, team_id
)
)
return {
returned_val = {
"event": {
"transaction_id": str(uuid.uuid4()),
"external_customer_id": external_customer_id,
@ -116,6 +120,11 @@ class LagoLogger(CustomLogger):
}
}
verbose_logger.debug(
"\033[91mLogged Lago Object:\n{}\033[0m\n".format(returned_val)
)
return returned_val
def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(


@ -1033,6 +1033,7 @@ class AzureChatCompletion(BaseLLM):
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
)
headers = dict(raw_response.headers)
response = raw_response.parse()
stringified_response = response.model_dump()
## LOGGING
@ -1045,6 +1046,8 @@ class AzureChatCompletion(BaseLLM):
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
response_type="embedding",
)
except Exception as e:
@ -1606,7 +1609,7 @@ class AzureChatCompletion(BaseLLM):
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if atranscription == True:
if atranscription is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
@ -1693,10 +1696,15 @@ class AzureChatCompletion(BaseLLM):
},
)
response = await async_azure_client.audio.transcriptions.create(
**data, timeout=timeout
raw_response = (
await async_azure_client.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
@ -1717,7 +1725,13 @@ class AzureChatCompletion(BaseLLM):
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
response = convert_to_model_response_object(
_response_headers=headers,
response_object=stringified_response,
model_response_object=model_response,
hidden_params=hidden_params,
response_type="audio_transcription",
) # type: ignore
return response
except Exception as e:
## LOGGING


@ -768,7 +768,8 @@ async def make_call(
def make_sync_call(
client: Optional[HTTPHandler],
client: Optional[HTTPHandler], # module-level client
gemini_client: Optional[HTTPHandler], # if passed by user
api_base: str,
headers: dict,
data: str,
@ -776,6 +777,8 @@ def make_sync_call(
messages: list,
logging_obj,
):
if gemini_client is not None:
client = gemini_client
if client is None:
client = HTTPHandler() # Create a new client if none provided
@ -1061,10 +1064,17 @@ class VertexLLM(BaseLLM):
os.getcwd(),
)
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
try:
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
except Exception:
raise Exception(
"Unable to load vertex credentials from environment. Got={}".format(
credentials
)
)
# Check if the JSON object contains Workload Identity Federation configuration
if "type" in json_obj and json_obj["type"] == "external_account":
@ -1438,7 +1448,11 @@ class VertexLLM(BaseLLM):
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
gemini_client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
api_base=url,
data=request_data_str,
model=model,
@ -1491,6 +1505,9 @@ class VertexLLM(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
self.accumulated_json = ""
self.sent_first_chunk = False
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@ -1560,29 +1577,80 @@ class ModelResponseIterator:
self.response_iterator = self.streaming_response
return self
def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
chunk = chunk.strip()
try:
json_chunk = json.loads(chunk)
except json.JSONDecodeError as e:
if (
self.sent_first_chunk is False
): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked.
self.chunk_type = "accumulated_json"
return self.handle_accumulated_json_chunk(chunk=chunk)
raise e
if self.sent_first_chunk is False:
self.sent_first_chunk = True
return self.chunk_parser(chunk=json_chunk)
def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
message = chunk.replace("data:", "").replace("\n\n", "")
# Accumulate JSON data
self.accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(self.accumulated_json)
self.accumulated_json = "" # reset after successful parsing
return self.chunk_parser(chunk=_data)
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
chunk = chunk.replace("data:", "")
if len(chunk) > 0:
"""
Check if initial chunk valid json
- if partial json -> enter accumulated json logic
- if valid - continue
"""
if self.chunk_type == "valid_json":
return self.handle_valid_json_chunk(chunk=chunk)
elif self.chunk_type == "accumulated_json":
return self.handle_accumulated_json_chunk(chunk=chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
if self.chunk_type == "accumulated_json" and self.accumulated_json:
return self.handle_accumulated_json_chunk(chunk="")
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
return self._common_chunk_parsing_logic(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
@ -1597,25 +1665,14 @@ class ModelResponseIterator:
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
if self.chunk_type == "accumulated_json" and self.accumulated_json:
return self.handle_accumulated_json_chunk(chunk="")
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
return self._common_chunk_parsing_logic(chunk=chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:


@ -25,7 +25,7 @@ from litellm.types.files import (
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@ -177,38 +177,34 @@ def _gemini_convert_messages_with_history(
assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if messages[msg_i].get("content", None) is not None and isinstance(
messages[msg_i]["content"], list
assistant_msg = ChatCompletionAssistantMessage(**messages[msg_i]) # type: ignore
if assistant_msg.get("content", None) is not None and isinstance(
assistant_msg["content"], list
):
_parts = []
for element in messages[msg_i]["content"]: # type: ignore
for element in assistant_msg["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"]) # type: ignore
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"] # type: ignore
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
assistant_content.extend(_parts)
elif (
messages[msg_i].get("content", None) is not None
and isinstance(messages[msg_i]["content"], str)
and messages[msg_i]["content"]
assistant_msg.get("content", None) is not None
and isinstance(assistant_msg["content"], str)
and assistant_msg["content"]
):
assistant_text = messages[msg_i]["content"] # either string or none
assistant_text = assistant_msg["content"] # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore
elif messages[msg_i].get(
"tool_calls", []
## HANDLE ASSISTANT FUNCTION CALL
if (
assistant_msg.get("tool_calls", []) is not None
or assistant_msg.get("function_call") is not None
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]) # type: ignore
)
last_message_with_tool_calls = messages[msg_i]
elif messages[msg_i].get("function_call") is not None:
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]) # type: ignore
convert_to_gemini_tool_call_invoke(assistant_msg)
)
last_message_with_tool_calls = assistant_msg
msg_i += 1


@ -1,12 +1,6 @@
model_list:
- model_name: "batch-gpt-4o-mini"
- model_name: "whisper"
litellm_params:
model: "azure/gpt-4o-mini"
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
model_info:
mode: batch
litellm_settings:
enable_loadbalancing_on_batch_endpoints: true
model: "azure/azure-whisper"
api_key: os.environ/AZURE_EUROPE_API_KEY
api_base: "https://my-endpoint-europe-berri-992.openai.azure.com/"


@ -529,6 +529,13 @@ async def user_update(
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = _duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(
seconds=duration_s
)
non_default_values["budget_reset_at"] = user_reset_at
## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)


@ -3703,6 +3703,7 @@ async def embeddings(
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(
@ -3715,6 +3716,7 @@ async def embeddings(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
**additional_headers,
)
)
await check_response_size_is_safe(response=response)
@ -4090,6 +4092,7 @@ async def audio_transcriptions(
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
fastapi_response.headers.update(
get_custom_headers(
@ -4102,6 +4105,7 @@ async def audio_transcriptions(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
call_id=litellm_call_id,
request_data=data,
**additional_headers,
)
)
@ -8019,8 +8023,13 @@ async def google_login(request: Request):
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params["state"] = state
elif "okta" in generic_authorization_endpoint:
redirect_params["state"] = (
uuid.uuid4().hex
) # set state param for okta - required
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
elif ui_username is not None:
# No Google, Microsoft SSO


@ -2757,3 +2757,69 @@ def test_gemini_function_call_parameter_in_messages():
"toolConfig": {"functionCallingConfig": {"mode": "AUTO"}},
"generationConfig": {},
} == mock_client.call_args.kwargs["json"]
def test_gemini_function_call_parameter_in_messages_2():
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
_gemini_convert_messages_with_history,
)
messages = [
{"role": "user", "content": "search for weather in boston (use `search`)"},
{
"role": "assistant",
"content": "Sure, let me check.",
"function_call": {
"name": "search",
"arguments": '{"queries": ["weather in boston"]}',
},
},
{
"role": "function",
"name": "search",
"content": "The weather in Boston is 100 degrees.",
},
]
returned_contents = _gemini_convert_messages_with_history(messages=messages)
assert returned_contents == [
{
"role": "user",
"parts": [{"text": "search for weather in boston (use `search`)"}],
},
{
"role": "model",
"parts": [
{"text": "Sure, let me check."},
{
"function_call": {
"name": "search",
"args": {
"fields": {
"key": "queries",
"value": {"list_value": ["weather in boston"]},
}
},
}
},
],
},
{
"parts": [
{
"function_response": {
"name": "search",
"response": {
"fields": {
"key": "content",
"value": {
"string_value": "The weather in Boston is 100 degrees."
},
}
},
}
}
]
},
]


@ -1222,3 +1222,13 @@ def test_not_found_error():
}
],
)
def test_bedrock_cross_region_inference():
litellm.set_verbose = True
response = completion(
model="bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=10,
temperature=0.1,
)


@ -2180,6 +2180,7 @@ def test_completion_openai():
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
],
)
@pytest.mark.flaky(retries=3, delay=1)
def test_completion_openai_pydantic(model):
try:
litellm.set_verbose = True


@ -1020,14 +1020,66 @@ def test_completion_cost_anthropic():
def test_completion_cost_deepseek():
litellm.set_verbose = True
model_name = "deepseek/deepseek-chat"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
messages_1 = [
{
"role": "system",
"content": "You are a history expert. The user will provide a series of questions, and your answers should be concise and start with `Answer:`",
},
{
"role": "user",
"content": "In what year did Qin Shi Huang unify the six states?",
},
{"role": "assistant", "content": "Answer: 221 BC"},
{"role": "user", "content": "Who was the founder of the Han Dynasty?"},
{"role": "assistant", "content": "Answer: Liu Bang"},
{"role": "user", "content": "Who was the last emperor of the Tang Dynasty?"},
{"role": "assistant", "content": "Answer: Li Zhu"},
{
"role": "user",
"content": "Who was the founding emperor of the Ming Dynasty?",
},
{"role": "assistant", "content": "Answer: Zhu Yuanzhang"},
{
"role": "user",
"content": "Who was the founding emperor of the Qing Dynasty?",
},
]
message_2 = [
{
"role": "system",
"content": "You are a history expert. The user will provide a series of questions, and your answers should be concise and start with `Answer:`",
},
{
"role": "user",
"content": "In what year did Qin Shi Huang unify the six states?",
},
{"role": "assistant", "content": "Answer: 221 BC"},
{"role": "user", "content": "Who was the founder of the Han Dynasty?"},
{"role": "assistant", "content": "Answer: Liu Bang"},
{"role": "user", "content": "Who was the last emperor of the Tang Dynasty?"},
{"role": "assistant", "content": "Answer: Li Zhu"},
{
"role": "user",
"content": "Who was the founding emperor of the Ming Dynasty?",
},
{"role": "assistant", "content": "Answer: Zhu Yuanzhang"},
{"role": "user", "content": "When did the Shang Dynasty fall?"},
]
try:
response_1 = litellm.completion(model=model_name, messages=messages)
response_2 = litellm.completion(model=model_name, messages=messages)
response_1 = litellm.completion(model=model_name, messages=messages_1)
response_2 = litellm.completion(model=model_name, messages=message_2)
# Add any assertions here to check the response
print(response_2)
assert response_2.usage.prompt_cache_hit_tokens is not None
assert response_2.usage.prompt_cache_miss_tokens is not None
assert (
response_2.usage.prompt_tokens == response_2.usage.prompt_cache_miss_tokens
)
assert (
response_2.usage._cache_read_input_tokens
== response_2.usage.prompt_cache_hit_tokens
)
except litellm.APIError as e:
pass
except Exception as e:


@ -9,6 +9,7 @@ import time
import traceback
import uuid
from typing import Tuple
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
@ -832,6 +833,129 @@ async def test_completion_gemini_stream(sync_mode):
# asyncio.run(test_acompletion_gemini_stream())
def gemini_mock_post_streaming(url, **kwargs):
# This generator simulates the streaming response with partial JSON content
def stream_response():
chunks = [
"{",
'"candidates": [{"content": {"parts": [{"text": "Twelve"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 8,"candidatesTokenCount": 1,"totalTokenCount": 9',
"}}\n\n", # This is the continuation of the previous chunk
'data: {"candidates": [{"content": {"parts": [{"text": "-year-old Finn was never one for adventure. He preferred the comfort of',
' his room, his nose buried in a book, to the chaotic world outside."}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 8,"candidatesTokenCount": 17,"totalTokenCount": 25}}\n\n',
# Add more chunks as needed
]
for chunk in chunks:
yield chunk
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "text/event-stream"}
mock_response.iter_lines = MagicMock(return_value=stream_response())
return mock_response
@pytest.mark.parametrize(
"sync_mode",
[True],
) # ,
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_completion_gemini_stream_accumulated_json(sync_mode):
try:
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
litellm.set_verbose = True
print("Streaming gemini response")
function1 = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
print("testing gemini streaming")
complete_response = ""
# Add any assertions here to check the response
non_empty_chunks = 0
chunks = []
if sync_mode:
client = HTTPHandler(concurrent_limit=1)
with patch.object(
client, "post", side_effect=gemini_mock_post_streaming
) as mock_client:
response = completion(
model="gemini/gemini-1.5-flash",
messages=messages,
stream=True,
functions=function1,
client=client,
)
for idx, chunk in enumerate(response):
print(chunk)
chunks.append(chunk)
# print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
print(f"finished: {finished}")
if finished:
break
non_empty_chunks += 1
complete_response += chunk
mock_client.assert_called_once()
else:
client = AsyncHTTPHandler(concurrent_limit=1)
with patch.object(
client, "post", side_effect=gemini_mock_post_streaming
) as mock_client:
response = await litellm.acompletion(
model="gemini/gemini-1.5-flash",
messages=messages,
stream=True,
functions=function1,
)
idx = 0
async for chunk in response:
print(chunk)
chunks.append(chunk)
# print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
non_empty_chunks += 1
complete_response += chunk
idx += 1
# if complete_response.strip() == "":
# raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
assert (
complete_response
== "Twelve-year-old Finn was never one for adventure. He preferred the comfort of his room, his nose buried in a book, to the chaotic world outside."
)
# assert non_empty_chunks > 1
except litellm.InternalServerError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
# if "429 Resource has been exhausted":
# return
pytest.fail(f"Error occurred: {e}")
def test_completion_mistral_api_mistral_large_function_call_with_streaming():


@ -363,7 +363,7 @@ class ChatCompletionUserMessage(TypedDict):
class ChatCompletionAssistantMessage(TypedDict, total=False):
role: Required[Literal["assistant"]]
content: Optional[str]
content: Optional[Union[str, Iterable[ChatCompletionTextObject]]]
name: Optional[str]
tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
function_call: Optional[ChatCompletionToolCallFunctionChunk]


@ -474,6 +474,13 @@ class Usage(CompletionUsage):
total_tokens: Optional[int] = None,
**params,
):
## DEEPSEEK PROMPT TOKEN HANDLING ## - follow the anthropic format, of having prompt tokens be just the non-cached token input. Enables accurate cost-tracking - Relevant issue: https://github.com/BerriAI/litellm/issues/5285
if (
"prompt_cache_miss_tokens" in params
and isinstance(params["prompt_cache_miss_tokens"], int)
and prompt_tokens is not None
):
prompt_tokens = params["prompt_cache_miss_tokens"]
data = {
"prompt_tokens": prompt_tokens or 0,
"completion_tokens": completion_tokens or 0,
@ -481,6 +488,7 @@ class Usage(CompletionUsage):
}
super().__init__(**data)
## ANTHROPIC MAPPING ##
if "cache_creation_input_tokens" in params and isinstance(
params["cache_creation_input_tokens"], int
):
@ -491,6 +499,12 @@ class Usage(CompletionUsage):
):
self._cache_read_input_tokens = params["cache_read_input_tokens"]
## DEEPSEEK MAPPING ##
if "prompt_cache_hit_tokens" in params and isinstance(
params["prompt_cache_hit_tokens"], int
):
self._cache_read_input_tokens = params["prompt_cache_hit_tokens"]
for k, v in params.items():
setattr(self, k, v)


@ -6146,6 +6146,7 @@ def convert_to_model_response_object(
] = None, # used for supporting 'json_schema' on older models
):
received_args = locals()
if _response_headers is not None:
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
@ -6230,13 +6231,8 @@ def convert_to_model_response_object(
model_response_object.choices = choice_list
if "usage" in response_object and response_object["usage"] is not None:
model_response_object.usage.completion_tokens = response_object["usage"].get("completion_tokens", 0) # type: ignore
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
special_keys = ["completion_tokens", "prompt_tokens", "total_tokens"]
for k, v in response_object["usage"].items():
if k not in special_keys:
setattr(model_response_object.usage, k, v) # type: ignore
usage_object = litellm.Usage(**response_object["usage"])
setattr(model_response_object, "usage", usage_object)
if "created" in response_object:
model_response_object.created = response_object["created"] or int(
time.time()