mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 10:14:26 +00:00)

fix(utils.py): azure tool calling streaming

parent 4cdd930fa2
commit e8331a4647
4 changed files with 55 additions and 12 deletions
@@ -194,7 +194,6 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["azure_ad_token"] = azure_ad_token
             azure_client = AsyncAzureOpenAI(**azure_client_params)
             response = await azure_client.chat.completions.create(**data)
-            response.model = "azure/" + str(response.model)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
@@ -95,7 +95,7 @@ def test_stream_chunk_builder_litellm_tool_call():
     try:
         litellm.set_verbose = False
         response = litellm.completion(
-            model="gpt-3.5-turbo",
+            model="azure/chatgpt-functioncalling",
             messages=messages,
             tools=tools_schema,
             stream=True,
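For context on what this test exercises: litellm streams the completion chunk by chunk, and litellm.stream_chunk_builder() reassembles the collected chunks into a single response whose tool calls can then be asserted on. A minimal sketch of that flow, assuming an Azure function-calling deployment exists under the model name used in the test:

import litellm

messages = [{"role": "user", "content": "What's the weather in Boston?"}]

# Stream the completion and keep every raw chunk.
response = litellm.completion(
    model="azure/chatgpt-functioncalling",  # assumed Azure deployment, as in the test
    messages=messages,
    tools=tools_schema,  # the tool definition added later in this diff
    stream=True,
)
chunks = [chunk for chunk in response]

# Rebuild a single, non-streaming-style response from the chunks.
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(rebuilt.choices[0].message.tool_calls)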
@@ -137,6 +137,30 @@ def streaming_format_tests(idx, chunk):
     print(f"extracted chunk: {extracted_chunk}")
     return extracted_chunk, finished
+
+tools_schema = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"]
+                    }
+                },
+                "required": ["location"]
+            }
+        }
+    }
+]
 
 # def test_completion_cohere_stream():
 #     # this is a flaky test due to the cohere API endpoint being unstable
 #     try:
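The schema above is a standard OpenAI-format tool definition: get_current_weather takes a required location string and an optional unit enum. Passed via tools=..., it lets the model answer with a structured tool call instead of text. A minimal non-streaming sketch of what that reply looks like, assuming the model decides to call the tool (model name and prompt are illustrative):

import json
import litellm

response = litellm.completion(
    model="azure/chatgpt-functioncalling",  # assumed Azure deployment
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools_schema,
)

# A tool-call reply carries tool_calls instead of text content.
tool_call = response.choices[0].message.tool_calls[0]
print(tool_call.function.name)                   # "get_current_weather"
print(json.loads(tool_call.function.arguments))  # e.g. {"location": "Boston, MA"}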
@@ -231,6 +255,26 @@ def test_completion_azure_stream():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_azure_stream()
+
+def test_completion_azure_function_calling_stream():
+    try:
+        litellm.set_verbose = False
+        user_message = "What is the current weather in Boston?"
+        messages = [{"content": user_message, "role": "user"}]
+        response = completion(
+            model="azure/chatgpt-functioncalling", messages=messages, stream=True, tools=tools_schema
+        )
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk)
+            if chunk["choices"][0]["finish_reason"] == "stop":
+                break
+            print(chunk["choices"][0]["finish_reason"])
+            print(chunk["choices"][0]["delta"]["content"])
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_azure_function_calling_stream()
 
 def test_completion_claude_stream():
     try:
         messages = [
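The new test only prints chunks, but it is worth spelling out how a streamed tool call actually arrives: the function name typically comes in the first delta, while the JSON argument string is split across many later deltas that the caller must concatenate. A hand-rolled sketch of that accumulation, assuming response is the streaming iterator from the test above:

# Accumulate a streamed tool call into a (name, arguments) pair.
name = None
arguments = ""
for chunk in response:
    if not chunk.choices:  # defensively skip empty chunks (see the utils.py fix below)
        continue
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        fragment = delta.tool_calls[0]
        if fragment.function.name:       # the name usually arrives once, up front
            name = fragment.function.name
        if fragment.function.arguments:  # the argument JSON arrives piecewise
            arguments += fragment.function.arguments
print(name, arguments)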
@@ -347,7 +391,7 @@ def test_completion_nlp_cloud_stream():
     except Exception as e:
         print(f"Error occurred: {e}")
         pytest.fail(f"Error occurred: {e}")
-test_completion_nlp_cloud_stream()
+# test_completion_nlp_cloud_stream()
 
 def test_completion_claude_stream_bad_key():
     try:
@@ -4933,7 +4933,6 @@ class CustomStreamWrapper:
                     is_finished = True
                     finish_reason = str_line.choices[0].finish_reason

                 return {
                     "text": text,
                     "is_finished": is_finished,
@@ -5173,7 +5172,7 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]

             model_response.model = self.model
             print_verbose(f"model_response: {model_response}; completion_obj: {completion_obj}")
             print_verbose(f"model_response finish reason 3: {model_response.choices[0].finish_reason}")
@@ -5196,11 +5195,14 @@ class CustomStreamWrapper:
             # enter this branch when no content has been passed in response
             original_chunk = response_obj.get("original_chunk", None)
             model_response.id = original_chunk.id
-            try:
-                delta = dict(original_chunk.choices[0].delta)
-                model_response.choices[0].delta = Delta(**delta)
-            except:
-                model_response.choices[0].delta = Delta()
+            if len(original_chunk.choices) > 0:
+                try:
+                    delta = dict(original_chunk.choices[0].delta)
+                    model_response.choices[0].delta = Delta(**delta)
+                except Exception as e:
+                    model_response.choices[0].delta = Delta()
+            else:
+                return
             model_response.system_fingerprint = original_chunk.system_fingerprint
             if self.sent_first_chunk == False:
                 model_response.choices[0].delta["role"] = "assistant"
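This hunk is the heart of the fix. Azure's streaming API can emit chunks whose choices list is empty (content-filter and other housekeeping chunks), and the old code indexed original_chunk.choices[0] under a bare except, so an empty chunk produced an empty Delta instead of being skipped, which is evidently what broke tool-call streams here. The new guard returns early on empty chunks, so only chunks that actually carry a delta are forwarded. The same pattern as a standalone sketch, with illustrative names:

def extract_delta(chunk):
    # Azure can stream housekeeping chunks with an empty choices list;
    # indexing chunk.choices[0] there would raise IndexError.
    if len(chunk.choices) == 0:
        return None  # caller should skip this chunk entirely
    return chunk.choices[0].delta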
@@ -5232,10 +5234,8 @@ class CustomStreamWrapper:
             else:
                 chunk = next(self.completion_stream)

             print_verbose(f"chunk in __next__: {chunk}")
             if chunk is not None and chunk != b'':
                 response = self.chunk_creator(chunk=chunk)
                 print_verbose(f"response in __next__: {response}")
                 if response is not None:
                     return response
         except StopIteration: