diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index e7baaf834e..efc8808212 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -154,18 +154,36 @@ class OpenAITextCompletionConfig():
                 and v is not None}
 
 class OpenAIChatCompletion(BaseLLM):
-    _client_session: Optional[httpx.Client] = None
-    _aclient_session: Optional[httpx.AsyncClient] = None
+    openai_client: Optional[openai.Client] = None
+    openai_aclient: Optional[openai.AsyncClient] = None
 
     def __init__(self) -> None:
         super().__init__()
+        self.openai_client = openai.OpenAI()
+        self.openai_aclient = openai.AsyncOpenAI()
 
-    def validate_environment(self, api_key):
-        headers = {
-            "content-type": "application/json",
-        }
-        if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
+    def validate_environment(self, api_key, api_base, headers):
+        if headers is None:
+            headers = {
+                "content-type": "application/json",
+            }
+            if api_key:
+                headers["Authorization"] = f"Bearer {api_key}"
+
+        self.openai_client.api_key = api_key
+        self.openai_aclient.api_key = api_key
+        if api_base:
+            if self.openai_client.base_url is None or self.openai_client.base_url != api_base:
+                if api_base.endswith("/"):
+                    self.openai_client._base_url = httpx.URL(url=api_base)
+                else:
+                    self.openai_client._base_url = httpx.URL(url=api_base+"/")
+            if self.openai_aclient.base_url is None or self.openai_aclient.base_url != api_base:
+                if api_base.endswith("/"):
+                    self.openai_aclient._base_url = httpx.URL(url=api_base)
+                else:
+                    self.openai_aclient._base_url = httpx.URL(url=api_base+"/")
         return headers
 
     def _retry_request(self, *args, **kwargs):
@@ -191,13 +209,9 @@ class OpenAIChatCompletion(BaseLLM):
                    logger_fn=None,
                    headers: Optional[dict]=None):
         super().completion()
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
         exception_mapping_worked = False
         try:
-            if headers is None:
-                headers = self.validate_environment(api_key=api_key)
-            api_base = f"{api_base}/chat/completions"
+            headers = self.validate_environment(api_key=api_key, api_base=api_base, headers=headers)
             if model is None or messages is None:
                 raise OpenAIError(status_code=422, message=f"Missing model or messages")
@@ -224,23 +238,8 @@ class OpenAIChatCompletion(BaseLLM):
             elif optional_params.get("stream", False):
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
-                if model in litellm.models_by_provider["openai"]:
-                    if api_key:
-                        openai.api_key = api_key
-                    response = openai.chat.completions.create(**data)
-                    return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
-                else:
-                    response = requests.post(
-                        url=api_base,
-                        json=data,
-                        headers=headers,
-                        timeout=600 # Set a 10-minute timeout for both connection and read
-                    )
-                    if response.status_code != 200:
-                        raise OpenAIError(status_code=response.status_code, message=response.text)
-
-                    ## RESPONSE OBJECT
-                    return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+                response = self.openai_client.chat.completions.create(**data)
+                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                 # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
@@ -270,19 +269,11 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        kwargs = locals()
+        response = None
         try:
-            async with httpx.AsyncClient() as client:
-                response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
-                response_json = response.json()
-                if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-                ## RESPONSE OBJECT
-                return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+            response = await self.openai_aclient.chat.completions.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
-            if isinstance(e, httpx.TimeoutException):
-                raise OpenAIError(status_code=500, message="Request Timeout Error")
             if response and hasattr(response, "text"):
                 raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
             else:
@@ -296,20 +287,10 @@ class OpenAIChatCompletion(BaseLLM):
                   model_response: ModelResponse,
                   model: str
     ):
-        with httpx.stream(
-                    url=f"{api_base}", # type: ignore
-                    json=data,
-                    headers=headers,
-                    method="POST",
-                    timeout=litellm.request_timeout
-                ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
-
-            completion_stream = response.iter_lines()
-            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
-            for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        response = self.openai_client.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
     async def async_streaming(self,
                           logging_obj,
@@ -318,20 +299,11 @@
                           headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
-                    url=f"{api_base}",
-                    json=data,
-                    headers=headers,
-                    method="POST",
-                    timeout=litellm.request_timeout
-                ) as response:
-            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        response = await self.openai_aclient.chat.completions.create(**data)
+        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+        async for transformed_chunk in streamwrapper:
+            yield transformed_chunk
+
     def embedding(self,
                 model: str,
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index df34a25496..bbb23a0c2f 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -12,7 +12,7 @@ class VertexAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
+        self.request = httpx.Request(method="POST", url=" https://cloud.google.com/vertex-ai/")
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 920fff3cc5..80981a62e9 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -23,7 +23,18 @@ def test_sync_response():
         response = completion(model="gpt-3.5-turbo", messages=messages, api_key=os.environ["OPENAI_API_KEY"])
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")
+# test_sync_response()
+def test_sync_response_anyscale():
+    litellm.set_verbose = True
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+    try:
+        response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+
+# test_sync_response_anyscale()
 
 def test_async_response():
     import asyncio
@@ -32,13 +43,28 @@
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+            # response = await response
+            print(f"response: {response}")
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+
+    asyncio.run(test_get_response())
+
+def test_async_anyscale_response():
+    import asyncio
+    litellm.set_verbose = True
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+            # response = await response
             print(f"response: {response}")
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
 
     asyncio.run(test_get_response())
-# test_async_response()
 
 def test_get_response_streaming():
     import asyncio
@@ -70,7 +96,7 @@
 
     asyncio.run(test_async_call())
 
-# test_get_response_streaming()
+test_get_response_streaming()
 
 def test_get_response_non_openai_streaming():
     import asyncio
@@ -79,7 +105,7 @@
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
-            response = await acompletion(model="huggingface/HuggingFaceH4/zephyr-7b-beta", messages=messages, stream=True)
+            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True)
            print(type(response))
 
            import inspect
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 314b8736b6..a7f7772687 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -374,7 +374,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_azure_stream()
+# test_completion_azure_stream()
 
 def test_completion_claude_stream():
     try:
@@ -829,6 +829,7 @@ def ai21_completion_call_bad_key():
 def test_openai_chat_completion_call():
     try:
         litellm.set_verbose = False
+        print(f"making openai chat completion call")
         response = completion(
             model="gpt-3.5-turbo", messages=messages, stream=True
         )
@@ -848,7 +849,7 @@ def test_openai_chat_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass
 
-# test_openai_chat_completion_call()
+test_openai_chat_completion_call()
 
 def test_openai_chat_completion_complete_response_call():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 0e0a4ced9a..63c470f0a2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4496,26 +4496,12 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
-            if "data: [DONE]" in str_line:
-                # anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
-                text = ""
+            if str_line.choices[0].delta.content is not None:
+                text = str_line.choices[0].delta.content
+            if str_line.choices[0].finish_reason:
                 is_finished = True
-                finish_reason = "stop"
-                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-            elif str_line.startswith("data:") and len(str_line[5:]) > 0:
-                str_line = str_line[5:]
-                data_json = json.loads(str_line)
-                print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
-                text = data_json["choices"][0]["delta"].get("content", "")
-                if data_json["choices"][0].get("finish_reason", None):
-                    is_finished = True
-                    finish_reason = data_json["choices"][0]["finish_reason"]
-                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-            elif "error" in str_line:
-                raise ValueError(f"Unable to parse response. Original response: {str_line}")
-            else:
-                return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-
+                finish_reason = str_line.choices[0].finish_reason
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
         except Exception as e:
             traceback.print_exc()
             raise e
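
Note for reviewers: the sketch below is illustrative only; it is not part of the patch, and the model name and prompt are placeholders. It shows the call pattern this diff standardizes on: a shared openai.OpenAI() client created once in OpenAIChatCompletion.__init__, streaming that yields SDK ChatCompletionChunk objects, and the rewritten CustomStreamWrapper chunk handling that reads choices[0].delta.content / choices[0].finish_reason directly instead of json-loading raw "data:" SSE lines.

    # Illustrative sketch (assumes openai>=1.x installed and OPENAI_API_KEY set)
    import openai

    client = openai.OpenAI()  # one shared client, mirroring OpenAIChatCompletion.__init__

    def parse_chunk(chunk) -> dict:
        # Mirrors the new handle_openai_chat_completion_chunk logic in utils.py:
        # read fields off the SDK chunk object instead of parsing "data: ..." lines.
        text = ""
        is_finished = False
        finish_reason = None
        if chunk.choices[0].delta.content is not None:
            text = chunk.choices[0].delta.content
        if chunk.choices[0].finish_reason:
            is_finished = True
            finish_reason = chunk.choices[0].finish_reason
        return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    for chunk in stream:
        print(parse_chunk(chunk))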