forked from phoenix/litellm-mirror

(fix) add linting check to ban creating AsyncHTTPHandler during LLM calling (#6855)

* fix triton
* fix TEXT_COMPLETION_CODESTRAL
* fix REPLICATE
* fix CLARIFAI
* fix HUGGINGFACE
* add test_no_async_http_handler_usage
* fix PREDIBASE
* fix anthropic use get_async_httpx_client
* fix vertex fine tuning
* fix dbricks get_async_httpx_client
* fix get_async_httpx_client vertex
* fix get_async_httpx_client
* fix get_async_httpx_client
* fix make_async_azure_httpx_request
* fix check_for_async_http_handler
* test: cleanup mistral model
* add check for AsyncClient
* fix check_for_async_http_handler
* fix get_async_httpx_client
* fix tests using in_memory_llm_clients_cache
* fix langfuse import
* fix import

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

parent 71ebf47cef
commit 920f4c9f82

26 changed files with 288 additions and 62 deletions
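The diff below replaces ad-hoc `AsyncHTTPHandler(...)` construction at each provider call site with the shared `get_async_httpx_client` helper, which returns a client cached per provider (and params) instead of building a fresh httpx client, and connection pool, on every request. A minimal before/after sketch of the repeated pattern; the Anthropic provider and the 600-second timeout are taken from one of the call sites touched below:

    import httpx
    import litellm
    from litellm.llms.custom_httpx.http_handler import (
        AsyncHTTPHandler,
        get_async_httpx_client,
    )

    # Before: a fresh client (and connection pool) was created on every call.
    client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

    # After: re-use a cached client keyed by provider; built only on a cache miss.
    client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.ANTHROPIC,
        params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
    )

The "before" line is exactly the shape the new ensure_async_clients_test.py check (added at the end of this diff) now rejects outside the allow-listed http_handler module.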
@@ -771,6 +771,7 @@ jobs:
 - run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
 - run: python ./tests/documentation_tests/test_env_keys.py
 - run: python ./tests/documentation_tests/test_api_docs.py
+- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
 - run: helm lint ./deploy/charts/litellm-helm

 db_migration_disable_update_check:

@@ -133,7 +133,7 @@ use_client: bool = False
 ssl_verify: Union[str, bool] = True
 ssl_certificate: Optional[str] = None
 disable_streaming_logging: bool = False
-in_memory_llm_clients_cache: dict = {}
+in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
 safe_memory_mode: bool = False
 enable_azure_ad_token_refresh: Optional[bool] = False
 ### DEFAULT AZURE API VERSION ###
@@ -12,7 +12,11 @@ from typing_extensions import overload
 import litellm
 from litellm.caching.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.types.utils import EmbeddingResponse
 from litellm.utils import (
 CustomStreamWrapper,

@@ -977,7 +981,10 @@ class AzureChatCompletion(BaseLLM):
 else:
 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-async_handler = AsyncHTTPHandler(**_params) # type: ignore
+async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.AZURE,
+params=_params,
+)
 else:
 async_handler = client # type: ignore
@@ -18,6 +18,7 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import ProviderField
 from litellm.utils import (

@@ -562,8 +563,9 @@ class OpenAIChatCompletion(BaseLLM):

 _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"

-if _cache_key in litellm.in_memory_llm_clients_cache:
-return litellm.in_memory_llm_clients_cache[_cache_key]
+_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
+if _cached_client:
+return _cached_client
 if is_async:
 _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
 api_key=api_key,

@@ -584,7 +586,11 @@ class OpenAIChatCompletion(BaseLLM):
 )

 ## SAVE CACHE KEY
-litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+litellm.in_memory_llm_clients_cache.set_cache(
+key=_cache_key,
+value=_new_client,
+ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+)
 return _new_client

 else:
@@ -13,7 +13,11 @@ import httpx
 import requests

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ..base import BaseLLM

@@ -162,7 +166,10 @@ class AnthropicTextCompletion(BaseLLM):
 client=None,
 ):
 if client is None:
-client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.ANTHROPIC,
+params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
+)

 response = await client.post(api_base, headers=headers, data=json.dumps(data))

@@ -198,7 +205,10 @@ class AnthropicTextCompletion(BaseLLM):
 client=None,
 ):
 if client is None:
-client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.ANTHROPIC,
+params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
+)

 response = await client.post(api_base, headers=headers, data=json.dumps(data))
@@ -74,7 +74,10 @@ class AzureAIEmbedding(OpenAIChatCompletion):
 client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ) -> EmbeddingResponse:
 if client is None or not isinstance(client, AsyncHTTPHandler):
-client = AsyncHTTPHandler(timeout=timeout, concurrent_limit=1)
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.AZURE_AI,
+params={"timeout": timeout},
+)

 url = "{}/images/embeddings".format(api_base)
@@ -9,7 +9,10 @@ import httpx
 import requests

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+get_async_httpx_client,
+)
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

 from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -185,7 +188,10 @@ async def async_completion(
 headers={},
 ):

-async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.CLARIFAI,
+params={"timeout": 600.0},
+)
 response = await async_handler.post(
 url=model, headers=headers, data=json.dumps(data)
 )
@@ -11,7 +11,11 @@ import requests # type: ignore

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.types.llms.bedrock import CohereEmbeddingRequest
 from litellm.utils import Choices, Message, ModelResponse, Usage

@@ -71,7 +75,10 @@ async def async_embedding(
 )
 ## COMPLETION CALL
 if client is None:
-client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.COHERE,
+params={"timeout": timeout},
+)

 try:
 response = await client.post(api_base, headers=headers, data=json.dumps(data))
@@ -7,6 +7,7 @@ import httpx
 from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport

 import litellm
+from litellm.caching import InMemoryCache

 from .types import httpxSpecialProvider

@@ -26,6 +27,7 @@ headers = {

 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
+_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour


 class AsyncHTTPHandler:

@@ -476,8 +478,9 @@ def get_async_httpx_client(
 pass

 _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
-if _cache_key_name in litellm.in_memory_llm_clients_cache:
-return litellm.in_memory_llm_clients_cache[_cache_key_name]
+_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+if _cached_client:
+return _cached_client

 if params is not None:
 _new_client = AsyncHTTPHandler(**params)

@@ -485,7 +488,11 @@ def get_async_httpx_client(
 _new_client = AsyncHTTPHandler(
 timeout=httpx.Timeout(timeout=600.0, connect=5.0)
 )
-litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+litellm.in_memory_llm_clients_cache.set_cache(
+key=_cache_key_name,
+value=_new_client,
+ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+)
 return _new_client

@@ -505,13 +512,18 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
 pass

 _cache_key_name = "httpx_client" + _params_key_name
-if _cache_key_name in litellm.in_memory_llm_clients_cache:
-return litellm.in_memory_llm_clients_cache[_cache_key_name]
+_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+if _cached_client:
+return _cached_client

 if params is not None:
 _new_client = HTTPHandler(**params)
 else:
 _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

-litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+litellm.in_memory_llm_clients_cache.set_cache(
+key=_cache_key_name,
+value=_new_client,
+ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+)
 return _new_client
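Taken together, the http_handler changes above move client caching from a plain dict lookup to litellm's `InMemoryCache`, with a one-hour TTL (`_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600`). A simplified sketch of the get-or-create logic that `get_async_httpx_client` and `_get_httpx_client` now share; the helper name below is made up for illustration, and the direct `AsyncHTTPHandler(...)` call is only acceptable because the real logic lives in the allow-listed http_handler module:

    import litellm
    from litellm.llms.custom_httpx.http_handler import (
        AsyncHTTPHandler,
        _DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    )


    def _get_or_create_async_client(cache_key: str, params: dict) -> AsyncHTTPHandler:
        """Hypothetical helper mirroring the caching pattern in the diff above."""
        # Cache hit: re-use the client previously built for this provider/params key.
        cached = litellm.in_memory_llm_clients_cache.get_cache(cache_key)
        if cached:
            return cached
        # Cache miss: build once, then store with a TTL so the client is
        # eventually dropped and rebuilt instead of living forever.
        new_client = AsyncHTTPHandler(**params)
        litellm.in_memory_llm_clients_cache.set_cache(
            key=cache_key,
            value=new_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )
        return new_client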
@@ -393,7 +393,10 @@ class DatabricksChatCompletion(BaseLLM):
 if timeout is None:
 timeout = httpx.Timeout(timeout=600.0, connect=5.0)

-self.async_handler = AsyncHTTPHandler(timeout=timeout)
+self.async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.DATABRICKS,
+params={"timeout": timeout},
+)

 try:
 response = await self.async_handler.post(

@@ -610,7 +613,10 @@ class DatabricksChatCompletion(BaseLLM):
 response = None
 try:
 if client is None or isinstance(client, AsyncHTTPHandler):
-self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
+self.async_client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.DATABRICKS,
+params={"timeout": timeout},
+)
 else:
 self.async_client = client
@@ -5,9 +5,14 @@ from typing import Any, Coroutine, Literal, Optional, Union
 import httpx
 from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters

+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
 VertexLLM,
 )

@@ -26,8 +31,9 @@ class VertexFineTuningAPI(VertexLLM):

 def __init__(self) -> None:
 super().__init__()
-self.async_handler = AsyncHTTPHandler(
-timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+self.async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.VERTEX_AI,
+params={"timeout": 600.0},
 )

 def convert_response_created_at(self, response: ResponseTuningJob):
@@ -263,7 +263,11 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
 return "text-generation-inference", model # default to tgi


-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)


 def get_hf_task_embedding_for_model(

@@ -301,7 +305,9 @@ async def async_get_hf_task_embedding_for_model(
 task_type, hf_tasks_embeddings
 )
 )
-http_client = AsyncHTTPHandler(concurrent_limit=1)
+http_client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.HUGGINGFACE,
+)

 model_info = await http_client.get(url=api_base)

@@ -1067,7 +1073,9 @@ class Huggingface(BaseLLM):
 )
 ## COMPLETION CALL
 if client is None:
-client = AsyncHTTPHandler(concurrent_limit=1)
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.HUGGINGFACE,
+)

 response = await client.post(api_base, headers=headers, data=json.dumps(data))
@@ -45,7 +45,10 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
 response = None
 try:
 if client is None or isinstance(client, AsyncHTTPHandler):
-self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
+self.async_client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.OPENAI,
+params={"timeout": timeout},
+)
 else:
 self.async_client = client
@@ -19,7 +19,10 @@ import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+get_async_httpx_client,
+)
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

 from .base import BaseLLM

@@ -549,7 +552,10 @@ class PredibaseChatCompletion(BaseLLM):
 headers={},
 ) -> ModelResponse:

-async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.PREDIBASE,
+params={"timeout": timeout},
+)
 try:
 response = await async_handler.post(
 api_base, headers=headers, data=json.dumps(data)
@@ -9,7 +9,10 @@ import httpx # type: ignore
 import requests # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+get_async_httpx_client,
+)
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -325,7 +328,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbose
 async def async_handle_prediction_response_streaming(
 prediction_url, api_token, print_verbose
 ):
-http_handler = AsyncHTTPHandler(concurrent_limit=1)
+http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
 previous_output = ""
 output_string = ""

@@ -560,7 +563,9 @@ async def async_completion(
 logging_obj,
 print_verbose,
 ) -> Union[ModelResponse, CustomStreamWrapper]:
-http_handler = AsyncHTTPHandler(concurrent_limit=1)
+http_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.REPLICATE,
+)
 prediction_url = await async_start_prediction(
 version_id,
 input_data,
@@ -18,7 +18,10 @@ import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+get_async_httpx_client,
+)
 from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
 Choices,

@@ -479,8 +482,9 @@ class CodestralTextCompletion(BaseLLM):
 headers={},
 ) -> TextCompletionResponse:

-async_handler = AsyncHTTPHandler(
-timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
+params={"timeout": timeout},
 )
 try:
@@ -8,7 +8,11 @@ import httpx # type: ignore
 import requests # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.utils import (
 Choices,
 CustomStreamWrapper,

@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
 logging_obj: Any,
 api_key: Optional[str] = None,
 ) -> EmbeddingResponse:
-async_handler = AsyncHTTPHandler(
-timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
 )

 response = await async_handler.post(url=api_base, data=json.dumps(data))

@@ -261,7 +265,9 @@ class TritonChatCompletion(BaseLLM):
 model_response,
 type_of_model,
 ) -> ModelResponse:
-handler = AsyncHTTPHandler()
+handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+)
 if stream:
 return self._ahandle_stream( # type: ignore
 handler, api_base, data_for_triton, model, logging_obj
@@ -1026,7 +1026,9 @@ async def make_call(
 logging_obj,
 ):
 if client is None:
-client = AsyncHTTPHandler() # Create a new client if none provided
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.VERTEX_AI,
+)

 try:
 response = await client.post(api_base, headers=headers, data=data, stream=True)
@@ -7,8 +7,13 @@ from typing import Any, List, Literal, Optional, Union

 import httpx

+import litellm
 from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
 VertexAIBatchEmbeddingsRequestBody,

@@ -150,7 +155,10 @@ class GoogleBatchEmbeddings(VertexLLM):
 else:
 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params) # type: ignore
+async_handler: AsyncHTTPHandler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.VERTEX_AI,
+params={"timeout": timeout},
+)
 else:
 async_handler = client # type: ignore
@@ -5,7 +5,11 @@ import httpx
 from openai.types.image import Image

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
 VertexLLM,
 )

@@ -156,7 +160,10 @@ class VertexImageGeneration(VertexLLM):
 else:
 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
+self.async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.VERTEX_AI,
+params={"timeout": timeout},
+)
 else:
 self.async_handler = client # type: ignore
@@ -5,7 +5,11 @@ import httpx

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+HTTPHandler,
+get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
 VertexAIError,
 VertexLLM,

@@ -172,7 +176,10 @@ class VertexMultimodalEmbedding(VertexLLM):
 if isinstance(timeout, float) or isinstance(timeout, int):
 timeout = httpx.Timeout(timeout)
 _params["timeout"] = timeout
-client = AsyncHTTPHandler(**_params) # type: ignore
+client = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.VERTEX_AI,
+params={"timeout": timeout},
+)
 else:
 client = client # type: ignore
@@ -14,6 +14,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.llms.prompt_templates.factory import (
 convert_to_anthropic_image_obj,
 convert_to_gemini_tool_call_invoke,

@@ -93,11 +94,15 @@ def _get_client_cache_key(


 def _get_client_from_cache(client_cache_key: str):
-return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
+return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)


 def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
-litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
+litellm.in_memory_llm_clients_cache.set_cache(
+key=client_cache_key,
+value=vertex_llm_model,
+ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+)


 def completion( # noqa: PLR0915
@@ -24,7 +24,10 @@ import httpx # type: ignore
 import requests # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+AsyncHTTPHandler,
+get_async_httpx_client,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

@@ -710,10 +713,13 @@ class RequestManager:
 if stream:
 request_params["stream"] = stream
 try:
-self.async_handler = AsyncHTTPHandler(
-timeout=httpx.Timeout(
-timeout=request_params.pop("timeout", 600.0), connect=5.0
-),
+self.async_handler = get_async_httpx_client(
+llm_provider=litellm.LlmProviders.WATSONX,
+params={
+"timeout": httpx.Timeout(
+timeout=request_params.pop("timeout", 600.0), connect=5.0
+),
+},
 )
 if "json" in request_params:
 request_params["data"] = json.dumps(request_params.pop("json", {}))
tests/code_coverage_tests/ensure_async_clients_test.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+import ast
+import os
+
+ALLOWED_FILES = [
+    # local files
+    "../../litellm/__init__.py",
+    "../../litellm/llms/custom_httpx/http_handler.py",
+    # when running on ci/cd
+    "./litellm/__init__.py",
+    "./litellm/llms/custom_httpx/http_handler.py",
+]
+
+warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request"
+
+
+def check_for_async_http_handler(file_path):
+    """
+    Checks if AsyncHttpHandler is instantiated in the given file.
+    Returns a list of line numbers where AsyncHttpHandler is used.
+    """
+    print("..checking file=", file_path)
+    if file_path in ALLOWED_FILES:
+        return []
+    with open(file_path, "r") as file:
+        try:
+            tree = ast.parse(file.read())
+        except SyntaxError:
+            print(f"Warning: Syntax error in file {file_path}")
+            return []
+
+    violations = []
+    target_names = [
+        "AsyncHttpHandler",
+        "AsyncHTTPHandler",
+        "AsyncClient",
+        "httpx.AsyncClient",
+    ]  # Add variations here
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Call):
+            if isinstance(node.func, ast.Name) and node.func.id.lower() in [
+                name.lower() for name in target_names
+            ]:
+                raise ValueError(
+                    f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
+                )
+    return violations
+
+
+def scan_directory_for_async_handler(base_dir):
+    """
+    Scans all Python files in the directory tree for AsyncHttpHandler usage.
+    Returns a dict of files and line numbers where violations were found.
+    """
+    violations = {}
+
+    for root, _, files in os.walk(base_dir):
+        for file in files:
+            if file.endswith(".py"):
+                file_path = os.path.join(root, file)
+                file_violations = check_for_async_http_handler(file_path)
+                if file_violations:
+                    violations[file_path] = file_violations
+
+    return violations
+
+
+def test_no_async_http_handler_usage():
+    """
+    Test to ensure AsyncHttpHandler is not used anywhere in the codebase.
+    """
+    base_dir = "./litellm"  # Adjust this path as needed
+
+    # base_dir = "../../litellm"  # LOCAL TESTING
+    violations = scan_directory_for_async_handler(base_dir)
+
+    if violations:
+        violation_messages = []
+        for file_path, line_numbers in violations.items():
+            violation_messages.append(
+                f"Found AsyncHttpHandler in {file_path} at lines: {line_numbers}"
+            )
+        raise AssertionError(
+            "AsyncHttpHandler usage detected:\n" + "\n".join(violation_messages)
+        )
+
+
+if __name__ == "__main__":
+    test_no_async_http_handler_usage()
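The new test walks every `.py` file under `./litellm` and raises as soon as it finds a call whose name matches `AsyncHTTPHandler`, `AsyncHttpHandler`, or `AsyncClient` in a file that is not allow-listed; CI runs it through the `python ./tests/code_coverage_tests/ensure_async_clients_test.py` step added to the CircleCI config above. A purely illustrative way to watch the checker fire on a throwaway file (the temp-file setup is not part of this PR, and it assumes the test module is importable, e.g. when run from tests/code_coverage_tests/):

    import tempfile

    from ensure_async_clients_test import check_for_async_http_handler

    # A file that builds an async client per call, i.e. the pattern being banned.
    bad_source = (
        "from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler\n"
        "client = AsyncHTTPHandler()\n"
    )

    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
        tmp.write(bad_source)

    try:
        check_for_async_http_handler(tmp.name)
    except ValueError as err:
        print("lint check fired as expected:", err)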
@@ -8,6 +8,7 @@ import traceback

 from dotenv import load_dotenv
 from openai.types.image import Image
+from litellm.caching import InMemoryCache

 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
@@ -107,7 +108,7 @@ class TestVertexImageGeneration(BaseImageGenTest):
 # comment this when running locally
 load_vertex_ai_credentials()

-litellm.in_memory_llm_clients_cache = {}
+litellm.in_memory_llm_clients_cache = InMemoryCache()
 return {
 "model": "vertex_ai/imagegeneration@006",
 "vertex_ai_project": "adroit-crow-413218",

@@ -118,13 +119,13 @@ class TestVertexImageGeneration(BaseImageGenTest):

 class TestBedrockSd3(BaseImageGenTest):
 def get_base_image_generation_call_args(self) -> dict:
-litellm.in_memory_llm_clients_cache = {}
+litellm.in_memory_llm_clients_cache = InMemoryCache()
 return {"model": "bedrock/stability.sd3-large-v1:0"}


 class TestBedrockSd1(BaseImageGenTest):
 def get_base_image_generation_call_args(self) -> dict:
-litellm.in_memory_llm_clients_cache = {}
+litellm.in_memory_llm_clients_cache = InMemoryCache()
 return {"model": "bedrock/stability.sd3-large-v1:0"}

@@ -181,7 +182,7 @@ def test_image_generation_azure_dall_e_3():
 @pytest.mark.asyncio
 async def test_aimage_generation_bedrock_with_optional_params():
 try:
-litellm.in_memory_llm_clients_cache = {}
+litellm.in_memory_llm_clients_cache = InMemoryCache()
 response = await litellm.aimage_generation(
 prompt="A cute baby sea otter",
 model="bedrock/stability.stable-diffusion-xl-v1",
@@ -12,6 +12,7 @@ sys.path.insert(0, os.path.abspath("../.."))

 import litellm
 from litellm import completion
+from litellm.caching import InMemoryCache

 litellm.num_retries = 3
 litellm.success_callback = ["langfuse"]

@@ -29,15 +30,20 @@ def langfuse_client():
 f"{os.environ['LANGFUSE_PUBLIC_KEY']}-{os.environ['LANGFUSE_SECRET_KEY']}"
 )
 # use a in memory langfuse client for testing, RAM util on ci/cd gets too high when we init many langfuse clients
-if _langfuse_cache_key in litellm.in_memory_llm_clients_cache:
-langfuse_client = litellm.in_memory_llm_clients_cache[_langfuse_cache_key]
+_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_langfuse_cache_key)
+if _cached_client:
+langfuse_client = _cached_client
 else:
 langfuse_client = langfuse.Langfuse(
 public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
 secret_key=os.environ["LANGFUSE_SECRET_KEY"],
 host=None,
 )
-litellm.in_memory_llm_clients_cache[_langfuse_cache_key] = langfuse_client
+litellm.in_memory_llm_clients_cache.set_cache(
+key=_langfuse_cache_key,
+value=langfuse_client,
+)

 print("NEW LANGFUSE CLIENT")