diff --git a/litellm/__init__.py b/litellm/__init__.py index 4ad362d509..8b7a13e4ff 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -369,7 +369,7 @@ from .llms.vertex_ai import VertexAIConfig from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig +from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig from .main import * # type: ignore diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index efc8808212..6bfd1aad3c 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -154,8 +154,8 @@ class OpenAITextCompletionConfig(): and v is not None} class OpenAIChatCompletion(BaseLLM): - openai_client: Optional[openai.Client] = None - openai_aclient: Optional[openai.AsyncClient] = None + openai_client: openai.Client + openai_aclient: openai.AsyncClient def __init__(self) -> None: super().__init__() @@ -232,13 +232,13 @@ class OpenAIChatCompletion(BaseLLM): try: if acompletion is True: if optional_params.get("stream", False): - return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) + return self.async_streaming(logging_obj=logging_obj, data=data, model=model) else: - return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response) + return self.acompletion(data=data, model_response=model_response) elif optional_params.get("stream", False): - return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) + return self.streaming(logging_obj=logging_obj, data=data, model=model) else: - response = self.openai_client.chat.completions.create(**data) + response = self.openai_client.chat.completions.create(**data) # type: ignore return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) except Exception as e: if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e): @@ -266,8 +266,7 @@ class OpenAIChatCompletion(BaseLLM): raise e async def acompletion(self, - api_base: str, - data: dict, headers: dict, + data: dict, model_response: ModelResponse): response = None try: @@ -281,10 +280,7 @@ class OpenAIChatCompletion(BaseLLM): def streaming(self, logging_obj, - api_base: str, data: dict, - headers: dict, - model_response: ModelResponse, model: str ): response = self.openai_client.chat.completions.create(**data) @@ -294,10 +290,7 @@ class OpenAIChatCompletion(BaseLLM): async def async_streaming(self, logging_obj, - api_base: str, data: dict, - headers: dict, - model_response: ModelResponse, model: str): response = await self.openai_aclient.chat.completions.create(**data) streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj) @@ -315,10 +308,8 @@ class OpenAIChatCompletion(BaseLLM): optional_params=None,): super().embedding() exception_mapping_worked = False - if self._client_session is None: - self._client_session = self.create_client_session() try: - headers = self.validate_environment(api_key) + headers = self.validate_environment(api_key, api_base=api_base, headers=None) api_base = f"{api_base}/embeddings" model = model data = { @@ -334,9 +325,7 @@ class OpenAIChatCompletion(BaseLLM): additional_args={"complete_input_dict": data}, ) ## COMPLETION CALL - response = requests.post( - api_base, headers=headers, json=data, timeout=litellm.request_timeout - ) + response = self.openai_client.embeddings.create(**data) # type: ignore ## LOGGING logging_obj.post_call( input=input, @@ -344,10 +333,8 @@ class OpenAIChatCompletion(BaseLLM): additional_args={"complete_input_dict": data}, original_response=response, ) - - if response.status_code!=200: - raise OpenAIError(message=response.text, status_code=response.status_code) - embedding_response = response.json() + + embedding_response = json.loads(response.model_dump_json()) output_data = [] for idx, embedding in enumerate(embedding_response["data"]): output_data.append( diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py index 80981a62e9..8bd0b2daa5 100644 --- a/litellm/tests/test_async_fn.py +++ b/litellm/tests/test_async_fn.py @@ -50,7 +50,7 @@ def test_async_response(): pytest.fail(f"An exception occurred: {e}") asyncio.run(test_get_response()) - +test_async_response() def test_async_anyscale_response(): import asyncio litellm.set_verbose = True diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index d533b7c42e..d9f1c9f673 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -20,7 +20,7 @@ def test_openai_embedding(): # print(f"response: {str(response)}") except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_openai_embedding() +test_openai_embedding() def test_openai_azure_embedding_simple(): try: