(feat) azure stream - count correct prompt tokens

This commit is contained in:
ishaan-jaff 2023-12-29 15:14:34 +05:30
parent 1e07f0fce8
commit a300ab9152

View file

@ -2472,15 +2472,15 @@ def openai_token_counter(
)
num_tokens = 0
if text:
num_tokens = len(encoding.encode(text, disallowed_special=()))
elif messages:
if messages is not None:
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value, disallowed_special=()))
if key == "name":
num_tokens += tokens_per_name
elif text is not None:
num_tokens = len(encoding.encode(text, disallowed_special=()))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
@ -2520,7 +2520,9 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
num_tokens = len(enc.ids)
elif tokenizer_json["type"] == "openai_tokenizer":
if model in litellm.open_ai_chat_completion_models:
num_tokens = openai_token_counter(text=text, model=model)
num_tokens = openai_token_counter(
text=text, model=model, messages=messages
)
else:
enc = tokenizer_json["tokenizer"].encode(text)
num_tokens = len(enc)
@ -7255,7 +7257,11 @@ class CustomStreamWrapper:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
else: # openai chat model
else: # openai / azure chat model
if self.custom_llm_provider == "azure":
if hasattr(chunk, "model"):
                        # for azure, we need to pass the model from the original chunk
self.model = chunk.model
response_obj = self.handle_openai_chat_completion_chunk(chunk)
if response_obj == None:
return