forked from phoenix/litellm-mirror
fix vertex fine tuning
parent 6af0494483
commit 0ee9f0fa44
3 changed files with 29 additions and 10 deletions
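
Despite the title, the same fix lands in three call sites: the Cohere async embedding path, the Vertex AI fine-tuning handler, and the watsonx RequestManager. Each one stops constructing AsyncHTTPHandler inline and instead asks get_async_httpx_client for a provider-scoped client. A minimal before/after sketch of the pattern, assuming get_async_httpx_client returns an AsyncHTTPHandler-compatible client and forwards params as handler keyword arguments:

```python
import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)

# Before: each call site built its own handler directly.
client = AsyncHTTPHandler(timeout=600.0)

# After: the shared factory builds the handler, keyed by the calling
# provider, so provider-level client settings live in one place.
client = get_async_httpx_client(
    llm_provider=litellm.LlmProviders.VERTEX_AI,
    params={"timeout": 600.0},
)
```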

```diff
@@ -11,7 +11,11 @@ import requests # type: ignore
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.llms.bedrock import CohereEmbeddingRequest
 from litellm.utils import Choices, Message, ModelResponse, Usage
```
```diff
@@ -71,7 +75,10 @@ async def async_embedding(
     )
     ## COMPLETION CALL
     if client is None:
-        client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.COHERE,
+            params={"timeout": timeout},
+        )
 
     try:
         response = await client.post(api_base, headers=headers, data=json.dumps(data))
```
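
One behavioral detail worth noting in this hunk: the old handler pinned concurrent_limit=1, and the replacement forwards only the timeout, so the concurrency cap now follows the factory's default. If params really does map straight onto AsyncHTTPHandler keyword arguments (an assumption; the diff only shows timeout being forwarded), the old cap could presumably be kept like this:

```python
import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

timeout = 600.0  # placeholder for the caller-supplied timeout

# Hedged sketch: forwarding the old concurrency cap through params,
# assuming params is passed through 1:1 as handler keyword arguments.
client = get_async_httpx_client(
    llm_provider=litellm.LlmProviders.COHERE,
    params={"timeout": timeout, "concurrent_limit": 1},
)
```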
```diff
@@ -5,9 +5,14 @@ from typing import Any, Coroutine, Literal, Optional, Union
 import httpx
 from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
```
```diff
@@ -26,8 +31,9 @@ class VertexFineTuningAPI(VertexLLM):
 
     def __init__(self) -> None:
         super().__init__()
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+            params={"timeout": 600.0},
         )
 
     def convert_response_created_at(self, response: ResponseTuningJob):
```
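
This hunk also narrows the timeout configuration: the old handler set an explicit httpx.Timeout(timeout=600.0, connect=5.0), while the new call passes a bare 600.0, so the 5-second connect timeout now falls back to whatever default the factory (or httpx) applies. The watsonx hunk below passes a full httpx.Timeout through params, which suggests the connect timeout could have been preserved the same way had it been wanted.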
```diff
@@ -24,7 +24,10 @@ import httpx # type: ignore
 import requests # type: ignore
 
 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
```
```diff
@@ -710,10 +713,13 @@ class RequestManager:
         if stream:
             request_params["stream"] = stream
         try:
-            self.async_handler = AsyncHTTPHandler(
-                timeout=httpx.Timeout(
-                    timeout=request_params.pop("timeout", 600.0), connect=5.0
-                ),
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.WATSONX,
+                params={
+                    "timeout": httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
+                },
             )
             if "json" in request_params:
                 request_params["data"] = json.dumps(request_params.pop("json", {}))
```
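
Unlike the Vertex hunk, this one keeps the full httpx.Timeout (including connect=5.0) and simply moves it inside params, so params["timeout"] evidently accepts either a bare float or an httpx.Timeout, mirroring what AsyncHTTPHandler's timeout argument took. Both forms as they appear in this commit, side by side:

```python
import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

# Bare-float form, as in the Vertex fine-tuning hunk.
simple = get_async_httpx_client(
    llm_provider=litellm.LlmProviders.WATSONX,
    params={"timeout": 600.0},
)

# Full httpx.Timeout form, as in the watsonx hunk above.
detailed = get_async_httpx_client(
    llm_provider=litellm.LlmProviders.WATSONX,
    params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)
```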