Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

commit bf0f8b824c (parent 3285113d2d)

    fix(azure.py): use openai client sdk for handling sync+async calling

7 changed files with 136 additions and 167 deletions
@@ -4,8 +4,9 @@ from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 from litellm import OpenAIConfig
-import litellm
+import litellm, json
 import httpx
+from openai import AzureOpenAI, AsyncAzureOpenAI

 class AzureOpenAIError(Exception):
     def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@@ -73,12 +74,10 @@ class AzureOpenAIConfig(OpenAIConfig):
                    top_p)

 class AzureChatCompletion(BaseLLM):
-    _client_session: Optional[httpx.Client] = None
-    _aclient_session: Optional[httpx.AsyncClient] = None

     def __init__(self) -> None:
         super().__init__()

     def validate_environment(self, api_key, azure_ad_token):
         headers = {
             "content-type": "application/json",
@@ -110,17 +109,12 @@ class AzureChatCompletion(BaseLLM):
             self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
-            if headers is None:
-                headers = self.validate_environment(api_key=api_key, azure_ad_token=azure_ad_token)

             if model is None or messages is None:
                 raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'

-            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"
             data = {
+                "model": model,
                 "messages": messages,
                 **optional_params
             }
@@ -137,41 +131,34 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
             else:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    timeout=litellm.request_timeout
-                )
-                if response.status_code != 200:
-                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+                response = azure_client.chat.completions.create(**data) # type: ignore
+                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
         except Exception as e:
             raise e

-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
+    async def acompletion(self,
+                          api_key: str,
+                          api_version: str,
+                          model: str,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          azure_ad_token: Optional[str]=None, ):
         try:
-            response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+            response = await azure_client.chat.completions.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
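Note: the rewritten completion()/acompletion() paths above lean entirely on the openai v1 SDK's Azure clients instead of hand-rolled httpx POSTs. A minimal sketch of that pattern follows; the key, endpoint, API version and deployment name are placeholders, not values from this commit.

import asyncio, json
from openai import AzureOpenAI, AsyncAzureOpenAI

# Placeholder credentials/deployment -- illustrative only.
api_key = "my-azure-key"
api_base = "https://my-resource.openai.azure.com/"
api_version = "2023-07-01-preview"
deployment = "chatgpt-v-2"

data = {"model": deployment, "messages": [{"role": "user", "content": "Hello"}]}

# Sync path, mirroring the new else-branch in completion()
client = AzureOpenAI(api_key=api_key, api_version=api_version,
                     azure_endpoint=api_base, azure_deployment=deployment)
response = client.chat.completions.create(**data)
print(json.loads(response.model_dump_json())["choices"][0]["message"]["content"])

# Async path, mirroring the new acompletion() helper
async def main():
    aclient = AsyncAzureOpenAI(api_key=api_key, api_version=api_version,
                               azure_endpoint=api_base, azure_deployment=deployment)
    aresponse = await aclient.chat.completions.create(**data)
    print(json.loads(aresponse.model_dump_json())["choices"][0]["message"]["content"])

asyncio.run(main())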
@@ -183,74 +170,52 @@ class AzureChatCompletion(BaseLLM):
     def streaming(self,
                   logging_obj,
                   api_base: str,
+                  api_key: str,
+                  api_version: str,
                   data: dict,
                   headers: dict,
                   model_response: ModelResponse,
-                  model: str
+                  model: str,
+                  azure_ad_token: Optional[str]=None,
     ):
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
-        with self._client_session.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message="An error occurred while streaming")
-
-            completion_stream = response.iter_lines()
-            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk

     async def async_streaming(self,
                               logging_obj,
                               api_base: str,
+                              api_key: str,
+                              api_version: str,
                               data: dict,
                               headers: dict,
                               model_response: ModelResponse,
-                              model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST",
-            timeout=litellm.request_timeout
-        ) as response:
-            if response.status_code != 200:
-                raise AzureOpenAIError(status_code=response.status_code, message=response.text)
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="azure",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+                              model: str,
+                              azure_ad_token: Optional[str]=None):
+        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
+        response = await azure_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk

     def embedding(self,
                   model: str,
                   input: list,
                   api_key: str,
                   api_base: str,
-                  azure_ad_token: str,
                   api_version: str,
                   logging_obj=None,
                   model_response=None,
-                  optional_params=None,):
+                  optional_params=None,
+                  azure_ad_token: Optional[str]=None):
         super().embedding()
         exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
-            headers = self.validate_environment(api_key, azure_ad_token=azure_ad_token)
-            # Ensure api_base ends with a trailing slash
-            if not api_base.endswith('/'):
-                api_base += '/'
-
-            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
-            model = model
+            azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token)
             data = {
                 "model": model,
                 "input": input,
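The streaming paths follow the same idea: with "stream": True in the request data, the SDK client returns an iterator of ChatCompletionChunk objects, which litellm hands to CustomStreamWrapper. A rough stand-alone sketch (placeholder key, endpoint and deployment, not taken from this diff):

from openai import AzureOpenAI

client = AzureOpenAI(api_key="my-azure-key", api_version="2023-07-01-preview",
                     azure_endpoint="https://my-resource.openai.azure.com/",
                     azure_deployment="chatgpt-v-2")  # placeholders

stream = client.chat.completions.create(
    model="chatgpt-v-2",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # the SDK now yields chunks instead of one response
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")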
@@ -263,10 +228,8 @@ class AzureChatCompletion(BaseLLM):
                     api_key=api_key,
                     additional_args={"complete_input_dict": data},
                 )
             ## COMPLETION CALL
-            response = self._client_session.post(
-                api_base, headers=headers, json=data, timeout=litellm.request_timeout
-            )
+            response = azure_client.embeddings.create(**data) # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -275,9 +238,7 @@ class AzureChatCompletion(BaseLLM):
                     original_response=response,
                 )

-            if response.status_code!=200:
-                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
-            embedding_response = response.json()
+            embedding_response = json.loads(response.model_dump_json())
             output_data = []
             for idx, embedding in enumerate(embedding_response["data"]):
                 output_data.append(
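embedding() gets the same treatment: the manual POST to .../openai/deployments/{model}/embeddings?api-version=... is replaced by azure_client.embeddings.create(), and the pydantic response is serialized via model_dump_json(). A hedged sketch of that call, with placeholder endpoint and deployment names:

import json
from openai import AzureOpenAI

client = AzureOpenAI(api_key="my-azure-key", api_version="2023-07-01-preview",
                     azure_endpoint="https://my-resource.openai.azure.com/",
                     azure_deployment="azure-embedding-model")  # placeholders

response = client.embeddings.create(model="azure-embedding-model",
                                    input=["good morning from litellm"])
embedding_response = json.loads(response.model_dump_json())
for idx, emb in enumerate(embedding_response["data"]):
    print(idx, len(emb["embedding"]))  # index and vector length per input string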
@@ -36,7 +36,7 @@ def test_sync_response_anyscale():

 # test_sync_response_anyscale()

-def test_async_response():
+def test_async_response_openai():
     import asyncio
     litellm.set_verbose = True
     async def test_get_response():
@@ -44,13 +44,27 @@ def test_async_response():
         messages = [{"content": user_message, "role": "user"}]
         try:
             response = await acompletion(model="gpt-3.5-turbo", messages=messages)
-            # response = await response
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())
-test_async_response()
+
+def test_async_response_azure():
+    import asyncio
+    litellm.set_verbose = True
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+            print(f"response: {response}")
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+
+    asyncio.run(test_get_response())
+

 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
@@ -73,7 +87,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
             print(type(response))

             import inspect
@@ -481,7 +481,7 @@ def test_completion_openai_litellm_key():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_openai_litellm_key()
+# test_completion_openai_litellm_key()

 def test_completion_openrouter1():
     try:
@@ -562,6 +562,8 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+test_completion_azure()
+
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
     # the request can fail since azure ad tokens expire after 30 mins, but the header MUST have the azure ad token
@@ -20,7 +20,7 @@ def test_openai_embedding():
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_openai_embedding()
+# test_openai_embedding()

 def test_openai_azure_embedding_simple():
     try:
@@ -1,69 +1,69 @@
-# import sys, os
-# import traceback
-# from dotenv import load_dotenv
-# import copy
+import sys, os
+import traceback
+from dotenv import load_dotenv
+import copy

-# load_dotenv()
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# ) # Adds the parent directory to the system path
-# import asyncio
-# from litellm import Router, Timeout
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import asyncio
+from litellm import Router, Timeout


-# async def call_acompletion(semaphore, router: Router, input_data):
-#     async with semaphore:
-#         try:
-#             # Use asyncio.wait_for to set a timeout for the task
-#             response = await router.acompletion(**input_data)
-#             # Handle the response as needed
-#             return response
-#         except Timeout:
-#             print(f"Task timed out: {input_data}")
-#             return None # You may choose to return something else or raise an exception
+async def call_acompletion(semaphore, router: Router, input_data):
+    async with semaphore:
+        try:
+            # Use asyncio.wait_for to set a timeout for the task
+            response = await router.acompletion(**input_data)
+            # Handle the response as needed
+            return response
+        except Timeout:
+            print(f"Task timed out: {input_data}")
+            return None # You may choose to return something else or raise an exception


-# async def main():
-#     # Initialize the Router
-#     model_list= [{
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "gpt-3.5-turbo",
-#             "api_key": os.getenv("OPENAI_API_KEY"),
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-v-2",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }, {
-#         "model_name": "gpt-3.5-turbo",
-#         "litellm_params": {
-#             "model": "azure/chatgpt-functioncalling",
-#             "api_key": os.getenv("AZURE_API_KEY"),
-#             "api_base": os.getenv("AZURE_API_BASE"),
-#             "api_version": os.getenv("AZURE_API_VERSION")
-#         },
-#     }]
-#     router = Router(model_list=model_list, num_retries=3, timeout=10)
+async def main():
+    # Initialize the Router
+    model_list= [{
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }, {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "api_version": os.getenv("AZURE_API_VERSION")
+        },
+    }]
+    router = Router(model_list=model_list, num_retries=3, timeout=10)

-#     # Create a semaphore with a capacity of 100
-#     semaphore = asyncio.Semaphore(100)
+    # Create a semaphore with a capacity of 100
+    semaphore = asyncio.Semaphore(100)

-#     # List to hold all task references
-#     tasks = []
+    # List to hold all task references
+    tasks = []

-#     # Launch 1000 tasks
-#     for _ in range(1000):
-#         task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
-#         tasks.append(task)
+    # Launch 1000 tasks
+    for _ in range(1000):
+        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
+        tasks.append(task)

-#     # Wait for all tasks to complete
-#     responses = await asyncio.gather(*tasks)
-#     # Process responses as needed
-#     print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
-# # Run the main function
-# asyncio.run(main())
+    # Wait for all tasks to complete
+    responses = await asyncio.gather(*tasks)
+    # Process responses as needed
+    print(f"NUMBER OF COMPLETED TASKS: {len(responses)}")
+# Run the main function
+asyncio.run(main())
@@ -374,7 +374,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_azure_stream()
+test_completion_azure_stream()

 def test_completion_claude_stream():
     try:
@@ -4505,11 +4505,12 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
-            if str_line.choices[0].delta.content is not None:
-                text = str_line.choices[0].delta.content
-            if str_line.choices[0].finish_reason:
-                is_finished = True
-                finish_reason = str_line.choices[0].finish_reason
+            if len(str_line.choices) > 0:
+                if str_line.choices[0].delta.content is not None:
+                    text = str_line.choices[0].delta.content
+                if str_line.choices[0].finish_reason:
+                    is_finished = True
+                    finish_reason = str_line.choices[0].finish_reason
             return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
         except Exception as e:
             traceback.print_exc()
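The added len(str_line.choices) > 0 guard matters because chunks coming back through the SDK stream can carry an empty choices list (Azure, for example, may emit a leading chunk with only filter metadata and no delta), and indexing choices[0] unconditionally would raise an IndexError. A small self-contained illustration of the fixed logic, using mock chunk objects rather than litellm's own types:

from types import SimpleNamespace

def parse_chunk(chunk):
    # Mirrors the fixed handler: only touch delta/finish_reason when a choice exists.
    text, is_finished, finish_reason = "", False, None
    if len(chunk.choices) > 0:
        if chunk.choices[0].delta.content is not None:
            text = chunk.choices[0].delta.content
        if chunk.choices[0].finish_reason:
            is_finished = True
            finish_reason = chunk.choices[0].finish_reason
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

empty_chunk = SimpleNamespace(choices=[])  # e.g. a metadata-only chunk
full_chunk = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="Hi"), finish_reason=None)])
print(parse_chunk(empty_chunk))  # {'text': '', 'is_finished': False, 'finish_reason': None}
print(parse_chunk(full_chunk))   # {'text': 'Hi', 'is_finished': False, 'finish_reason': None}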
@@ -4642,15 +4643,6 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
-            elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                response_obj = self.handle_azure_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                print_verbose(f"response_obj: {response_obj}")
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                print_verbose(f"len(completion_obj['content']: {len(completion_obj['content'])}")
-                if response_obj["is_finished"]:
-                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
-                    print_verbose(f"model_response finish reason 2: {model_response.choices[0].finish_reason}")
             elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
                 response_obj = self.handle_maritalk_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]