Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(invoke_handler.py): support cache token tracking on converse streaming
This commit is contained in:
parent 96bba9354e
commit 0af6cde994
1 changed file with 5 additions and 7 deletions
invoke_handler.py

@@ -72,6 +72,9 @@ _response_stream_shape_cache = None
 bedrock_tool_name_mappings: InMemoryCache = InMemoryCache(
     max_size_in_memory=50, default_ttl=600
 )
+from litellm.llms.bedrock.chat.converse_transformation import AmazonConverseConfig
+
+converse_config = AmazonConverseConfig()
 
 
 class AmazonCohereChatConfig:
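The new import and the module-level converse_config = AmazonConverseConfig() instance expose the Converse usage transformation to the streaming decoder further down in the file; a single shared instance is presumably used so the decoder does not have to construct a new config object for every streamed chunk.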
@@ -1274,7 +1277,7 @@ class AWSEventStreamDecoder:
         text = ""
         tool_use: Optional[ChatCompletionToolCallChunk] = None
         finish_reason = ""
-        usage: Optional[ChatCompletionUsageBlock] = None
+        usage: Optional[Usage] = None
         provider_specific_fields: dict = {}
         reasoning_content: Optional[str] = None
         thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
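The annotation change matters because, as constructed in the code removed below, ChatCompletionUsageBlock only populated the three basic counters, whereas the richer Usage object can also express cached-prompt details. A minimal sketch of the OpenAI-style usage shape such a richer object can represent (field names follow the OpenAI Chat Completions usage format; the exact litellm attributes are not shown here):

# Illustrative only: an OpenAI-style usage payload that includes prompt-cache details.
example_usage = {
    "prompt_tokens": 1200,
    "completion_tokens": 50,
    "total_tokens": 1250,
    "prompt_tokens_details": {"cached_tokens": 1024},  # tokens served from the prompt cache
}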
@@ -1350,12 +1353,7 @@ class AWSEventStreamDecoder:
         elif "stopReason" in chunk_data:
             finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
         elif "usage" in chunk_data:
-            usage = ChatCompletionUsageBlock(
-                prompt_tokens=chunk_data.get("inputTokens", 0),
-                completion_tokens=chunk_data.get("outputTokens", 0),
-                total_tokens=chunk_data.get("totalTokens", 0),
-            )
-
+            usage = converse_config._transform_usage(chunk_data.get("usage", {}))
         model_response_provider_specific_fields = {}
         if "trace" in chunk_data:
             trace = chunk_data.get("trace")
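This hunk is the substance of the fix: the hand-built usage block read only inputTokens, outputTokens, and totalTokens, so any cache-related counters in the stream's usage event were dropped. Delegating to converse_config._transform_usage reuses the same usage mapping as the non-streaming Converse path. A rough, hypothetical sketch of that kind of mapping, assuming Bedrock reports cacheReadInputTokens and cacheWriteInputTokens alongside the standard counters (this is not litellm's actual implementation):

from typing import Any, Dict


def transform_converse_usage_sketch(usage: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical illustration, not litellm's _transform_usage.
    cache_read = usage.get("cacheReadInputTokens", 0)    # tokens served from the prompt cache
    cache_write = usage.get("cacheWriteInputTokens", 0)  # tokens written to the prompt cache
    return {
        "prompt_tokens": usage.get("inputTokens", 0),
        "completion_tokens": usage.get("outputTokens", 0),
        "total_tokens": usage.get("totalTokens", 0),
        # Surface cache counters in an OpenAI-style details block; whether they are
        # already included in inputTokens depends on the provider's accounting.
        "prompt_tokens_details": {"cached_tokens": cache_read},
        "cache_creation_input_tokens": cache_write,
    }


# Example: a streamed usage event with a prompt-cache hit.
print(transform_converse_usage_sketch(
    {"inputTokens": 12, "outputTokens": 40, "totalTokens": 52, "cacheReadInputTokens": 1024}
))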