Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
aiohttp_openai/ fixes - allow using aiohttp_openai/gpt-4o (#7598)

* fixes for get_complete_url
* update aiohttp tests
* fix event loop for aiohttp
* ci/cd run again
* test_aiohttp_openai

parent 744beac754
commit 2ca0977921

5 changed files with 106 additions and 61 deletions
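What this enables, in short: the aiohttp_openai/ provider can now be given a plain OpenAI model name such as aiohttp_openai/gpt-4o, and litellm builds the /chat/completions URL and Authorization header itself. A minimal usage sketch, based on the new test added at the end of this commit (assumes OPENAI_API_KEY is set in the environment; see the api_key fallback chain added to completion() below). The hunks that follow walk through the change file by file.

import asyncio
import litellm

async def main():
    # routes through the aiohttp-based handler; the /chat/completions suffix
    # and the Bearer header are filled in by the changes in this commit
    response = await litellm.acompletion(
        model="aiohttp_openai/gpt-4o",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(response)

asyncio.run(main())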
@@ -9,7 +9,7 @@ New config to ensure we introduce this without causing breaking changes for user
 from typing import TYPE_CHECKING, Any, List, Optional
 
-import httpx
+from aiohttp import ClientResponse
 
 from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
 from litellm.types.llms.openai import AllMessageValues
 
@@ -24,6 +24,22 @@ else:
 
 
 class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        Ensure - /v1/chat/completions is at the end of the url
+
+        """
+
+        if not api_base.endswith("/chat/completions"):
+            api_base += "/chat/completions"
+        return api_base
+
     def validate_environment(
         self,
         headers: dict,
@@ -33,12 +49,12 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
-        return {}
+        return {"Authorization": f"Bearer {api_key}"}
 
-    def transform_response(
+    async def transform_response(  # type: ignore
         self,
         model: str,
-        raw_response: httpx.Response,
+        raw_response: ClientResponse,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         request_data: dict,
@@ -49,4 +65,5 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        return ModelResponse(**raw_response.json())
+        _json_response = await raw_response.json()
+        return ModelResponse(**_json_response)
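For clarity, the URL handling added above behaves like this standalone re-implementation (a sketch copied out of the diff, not an import from litellm):

def get_complete_url(api_base: str) -> str:
    # append the chat completions route only when it is not already present
    if not api_base.endswith("/chat/completions"):
        api_base += "/chat/completions"
    return api_base

assert get_complete_url("https://api.openai.com/v1") == "https://api.openai.com/v1/chat/completions"
assert get_complete_url("https://example.com/v1/chat/completions") == "https://example.com/v1/chat/completions"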
@@ -172,9 +172,19 @@ class BaseLLMAIOHTTPHandler:
             litellm_params=litellm_params,
             stream=False,
         )
-        _json_response = await _response.json()
-
-        return _json_response
+        _transformed_response = await provider_config.transform_response(  # type: ignore
+            model=model,
+            raw_response=_response,  # type: ignore
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+        return _transformed_response
 
     def completion(
         self,
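The switch from httpx.Response to aiohttp's ClientResponse is also why transform_response became a coroutine: with aiohttp, .json() must be awaited. A minimal illustration outside of litellm (the URL is a placeholder):

import asyncio
import aiohttp

async def fetch_json(url: str) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            # ClientResponse.json() is a coroutine, unlike httpx.Response.json()
            return await resp.json()

# asyncio.run(fetch_json("https://example.com/api/data"))  # placeholder URL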
@@ -1478,6 +1478,43 @@ def completion(  # type: ignore  # noqa: PLR0915
                 custom_llm_provider=custom_llm_provider,
                 encoding=encoding,
             )
+        elif custom_llm_provider == "aiohttp_openai":
+            # NEW aiohttp provider for 10-100x higher RPS
+            api_base = (
+                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or litellm.api_base
+                or get_secret("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            # set API KEY
+            api_key = (
+                api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or get_secret("OPENAI_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+            response = base_llm_aiohttp_handler.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                timeout=timeout,
+                client=client,
+                custom_llm_provider=custom_llm_provider,
+                encoding=encoding,
+                stream=stream,
+            )
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -2802,42 +2839,6 @@ def completion(  # type: ignore  # noqa: PLR0915
             )
             return response
             response = model_response
-        elif custom_llm_provider == "aiohttp_openai":
-            api_base = (
-                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
-                or litellm.api_base
-                or get_secret("OPENAI_API_BASE")
-                or "https://api.openai.com/v1"
-            )
-            # set API KEY
-            api_key = (
-                api_key
-                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
-                or litellm.openai_key
-                or get_secret("OPENAI_API_KEY")
-            )
-
-            headers = headers or litellm.headers
-
-            if extra_headers is not None:
-                optional_params["extra_headers"] = extra_headers
-            response = base_llm_aiohttp_handler.completion(
-                model=model,
-                messages=messages,
-                headers=headers,
-                model_response=model_response,
-                api_key=api_key,
-                api_base=api_base,
-                acompletion=acompletion,
-                logging_obj=logging,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                timeout=timeout,
-                client=client,
-                custom_llm_provider=custom_llm_provider,
-                encoding=encoding,
-                stream=stream,
-            )
         elif custom_llm_provider == "custom":
             url = litellm.api_base or api_base or ""
             if url is None or url == "":
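The two hunks above are a move rather than new logic: the aiohttp_openai branch previously sat far down the provider chain, after the `model in litellm.open_ai_chat_completion_models` check, so a request for aiohttp_openai/gpt-4o (where the stripped model name "gpt-4o" is in that list) would be matched by the generic OpenAI branch first and never reach the aiohttp handler. Hoisting the branch above that check is what makes aiohttp_openai/gpt-4o usable. The api_base / api_key resolution itself is unchanged and follows the usual explicit-argument, then module setting, then environment variable, then default order; roughly (a simplified sketch of that fallback chain, not litellm's actual get_secret):

import os

def resolve_openai_settings(api_base=None, api_key=None):
    # explicit argument -> environment variable -> default, mirroring the diff
    api_base = api_base or os.environ.get("OPENAI_API_BASE") or "https://api.openai.com/v1"
    api_key = api_key or os.environ.get("OPENAI_API_KEY")
    return api_base, api_key

print(resolve_openai_settings())  # ('https://api.openai.com/v1', None) when nothing is set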
@@ -3,7 +3,7 @@
 import importlib
 import os
 import sys
-
+import asyncio
 import pytest
 
 sys.path.insert(
@@ -12,31 +12,38 @@ sys.path.insert(
 import litellm
 
 
+@pytest.fixture(scope="session")
+def event_loop():
+    """Create an instance of the default event loop for each test session."""
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+
+
 @pytest.fixture(scope="function", autouse=True)
-def setup_and_teardown():
-    """
-    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
-    """
-    curr_dir = os.getcwd()  # Get the current working directory
-    sys.path.insert(
-        0, os.path.abspath("../..")
-    )  # Adds the project directory to the system path
+def setup_and_teardown(event_loop):  # Add event_loop as a dependency
+    curr_dir = os.getcwd()
+    sys.path.insert(0, os.path.abspath("../.."))
 
     import litellm
     from litellm import Router
 
     importlib.reload(litellm)
-    import asyncio
 
-    loop = asyncio.get_event_loop_policy().new_event_loop()
-    asyncio.set_event_loop(loop)
+    # Set the event loop from the fixture
+    asyncio.set_event_loop(event_loop)
 
     print(litellm)
-    # from litellm import Router, completion, aembedding, acompletion, embedding
     yield
 
-    # Teardown code (executes after the yield point)
-    loop.close()  # Close the loop created earlier
-    asyncio.set_event_loop(None)  # Remove the reference to the loop
+    # Clean up any pending tasks
+    pending = asyncio.all_tasks(event_loop)
+    for task in pending:
+        task.cancel()
+
+    # Run the event loop until all tasks are cancelled
+    if pending:
+        event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
 
 
 def pytest_collection_modifyitems(config, items):
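The session-scoped event_loop fixture above follows the long-standing pytest-asyncio pattern of overriding its built-in event_loop fixture so every async test in the session shares one loop; setup_and_teardown then installs that loop and drains whatever tasks are still pending after each test. The teardown step pulled out on its own, as a sketch (assumes it runs outside a running loop, as pytest fixture teardown does):

import asyncio

def drain_pending_tasks(loop: asyncio.AbstractEventLoop) -> None:
    # cancel everything still queued on the loop, then let the cancellations
    # propagate so the next test starts with a clean loop
    pending = asyncio.all_tasks(loop)
    for task in pending:
        task.cancel()
    if pending:
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))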
@@ -11,7 +11,7 @@ sys.path.insert(
 import litellm
 
 
-@pytest.mark.asyncio
+@pytest.mark.asyncio()
 async def test_aiohttp_openai():
     litellm.set_verbose = True
     response = await litellm.acompletion(
@@ -21,3 +21,13 @@ async def test_aiohttp_openai():
         api_key="fake-key",
     )
     print(response)
+
+
+@pytest.mark.asyncio()
+async def test_aiohttp_openai_gpt_4o():
+    litellm.set_verbose = True
+    response = await litellm.acompletion(
+        model="aiohttp_openai/gpt-4o",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+    )
+    print(response)