fix vertex fine tuning

Ishaan Jaff 2024-11-21 10:20:16 -08:00
parent 6af0494483
commit 0ee9f0fa44
3 changed files with 29 additions and 10 deletions

View file

@@ -11,7 +11,11 @@ import requests  # type: ignore
 
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.llms.bedrock import CohereEmbeddingRequest
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -71,7 +75,10 @@ async def async_embedding(
     )
     ## COMPLETION CALL
     if client is None:
-        client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.COHERE,
+            params={"timeout": timeout},
+        )
     try:
         response = await client.post(api_base, headers=headers, data=json.dumps(data))
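
Both hunks above, and the two files that follow, apply the same refactor: instead of constructing a throwaway AsyncHTTPHandler inline at each call site, the code asks get_async_httpx_client for a client, passing a provider enum plus a params dict. The point of such a factory is client reuse. A minimal sketch of the idea follows, assuming a simple per-provider cache; the trimmed-down AsyncHTTPHandler and the _client_cache dict are illustrative stand-ins, not litellm's actual implementation.

    from typing import Optional, Union

    import httpx


    class AsyncHTTPHandler:
        # Simplified stand-in for litellm's async HTTP wrapper.
        def __init__(self, timeout: Optional[Union[httpx.Timeout, float]] = None):
            self.client = httpx.AsyncClient(timeout=timeout)

        async def post(self, url: str, headers: dict, data: str) -> httpx.Response:
            return await self.client.post(url, headers=headers, content=data)


    # Illustrative cache: one client per (provider, params) pair, so repeated
    # calls share pooled connections instead of re-doing TLS setup each time.
    _client_cache: dict[str, AsyncHTTPHandler] = {}


    def get_async_httpx_client(
        llm_provider: str, params: Optional[dict] = None
    ) -> AsyncHTTPHandler:
        key = f"{llm_provider}:{sorted((params or {}).items())}"
        if key not in _client_cache:
            _client_cache[key] = AsyncHTTPHandler(**(params or {}))
        return _client_cache[key]

Reusing one pooled httpx.AsyncClient per provider keeps connections alive across requests rather than re-negotiating them on every call, which is why a shared factory beats per-call construction.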

View file

@@ -5,9 +5,14 @@ from typing import Any, Coroutine, Literal, Optional, Union
 
 import httpx
 from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
 
 import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -26,8 +31,9 @@ class VertexFineTuningAPI(VertexLLM):
 
     def __init__(self) -> None:
         super().__init__()
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+            params={"timeout": 600.0},
         )
 
     def convert_response_created_at(self, response: ResponseTuningJob):
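
One nuance in this hunk: the removed code pinned a 5-second connect timeout via httpx.Timeout(timeout=600.0, connect=5.0), while the replacement passes a bare 600.0, so the connect budget now falls back to whatever default the factory or httpx applies. If the old connect limit matters, a full httpx.Timeout can still be passed through params. The snippet below is a hypothetical variant of the new call showing that; it is not what this commit ships.

    import httpx

    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    # Hypothetical: keep the 5s connect budget while routing through the factory.
    async_handler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.VERTEX_AI,
        params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
    )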

View file

@@ -24,7 +24,10 @@ import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
@@ -710,10 +713,13 @@ class RequestManager:
         if stream:
            request_params["stream"] = stream
         try:
-            self.async_handler = AsyncHTTPHandler(
-                timeout=httpx.Timeout(
-                    timeout=request_params.pop("timeout", 600.0), connect=5.0
-                ),
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.WATSONX,
+                params={
+                    "timeout": httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
+                },
             )
             if "json" in request_params:
                 request_params["data"] = json.dumps(request_params.pop("json", {}))
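
Unlike the Vertex hunk, this one keeps the per-request timeout: request_params.pop("timeout", 600.0) both reads the value (defaulting to 600 seconds) and removes the key, so the timeout configures the client without also being forwarded to the eventual request call, and the json payload is then serialized into data. A small self-contained illustration of that pop-then-serialize sequence, with made-up values:

    import json

    # Made-up request_params, shaped like what RequestManager assembles.
    request_params = {
        "url": "https://example.invalid",
        "json": {"input": "hi"},
        "timeout": 120,
    }

    # pop() consumes the timeout so it is not forwarded as a request kwarg.
    timeout = request_params.pop("timeout", 600.0)
    assert "timeout" not in request_params

    # The hunk then serializes the json payload into a data= string.
    if "json" in request_params:
        request_params["data"] = json.dumps(request_params.pop("json", {}))

    assert request_params == {
        "url": "https://example.invalid",
        "data": '{"input": "hi"}',
    }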