feat(vertex_ai_partner.py): add vertex ai codestral FIM support

Closes https://github.com/BerriAI/litellm/issues/4984
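
A sketch of how the new support would be called once merged. text_completion() with a prompt/suffix pair is the existing Codestral FIM interface; the "vertex_ai/codestral@2405" model id is an assumption for illustration.

    import litellm

    # FIM request: the model fills in the code between prompt and suffix.
    # The "vertex_ai/" prefix routes through the Vertex AI partner
    # integration added in this commit (model id assumed for the example).
    response = litellm.text_completion(
        model="vertex_ai/codestral@2405",
        prompt="def is_odd(n):\n    return n % 2 == 1\n\ndef test_is_odd():",
        suffix="return True",
    )
    print(response.choices[0].text)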
Krrish Dholakia 2024-08-01 17:10:27 -07:00
parent 246b3227a9
commit 010d5ed81d
5 changed files with 96 additions and 32 deletions

litellm/llms/text_completion_codestral.py

@@ -1,28 +1,33 @@
 # What is this?
 ## Controller file for TextCompletionCodestral Integration - https://codestral.com/
-from functools import partial
-import os, types
-import traceback
+
 import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
-from typing import Callable, Optional, List, Literal, Union
+import traceback
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
-    TextCompletionResponse,
-    Usage,
+    Choices,
     CustomStreamWrapper,
     Message,
-    Choices,
+    TextCompletionResponse,
+    Usage,
 )
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.types.llms.databricks import GenericStreamingChunk
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
 from .base import BaseLLM
-import httpx  # type: ignore
+from .prompt_templates.factory import custom_prompt, prompt_factory
 
 
 class TextCompletionCodestralError(Exception):
@@ -329,7 +334,12 @@ class CodestralTextCompletion(BaseLLM):
     ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
         headers = self._validate_environment(api_key, headers)
 
-        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
+        if optional_params.pop("custom_endpoint", None) is True:
+            completion_url = api_base
+        else:
+            completion_url = (
+                api_base or "https://codestral.mistral.ai/v1/fim/completions"
+            )
 
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
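
Standalone, the routing this hunk introduces looks as follows (a sketch; the Vertex-style URL is a hypothetical placeholder). Callers that set custom_endpoint get their api_base used verbatim; everyone else keeps the old fallback to Mistral's hosted FIM endpoint.

    # Hypothetical values, for illustration only.
    optional_params = {"custom_endpoint": True}
    api_base = "https://example-aiplatform.googleapis.com/v1/publishers/mistralai/models/codestral:rawPredict"

    # Same logic as the diff above: pop() also strips the flag so it is
    # not forwarded to the provider as a request parameter.
    if optional_params.pop("custom_endpoint", None) is True:
        completion_url = api_base
    else:
        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"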
@@ -426,6 +436,7 @@ class CodestralTextCompletion(BaseLLM):
             return _response
         ### SYNC COMPLETION
         else:
+
             response = requests.post(
                 url=completion_url,
                 headers=headers,
@@ -464,8 +475,11 @@
         headers={},
     ) -> TextCompletionResponse:
 
-        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        async_handler = AsyncHTTPHandler(
+            timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+        )
         try:
+
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
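
The concurrent_limit=1 argument pins this handler to a single pooled connection. A rough httpx equivalent, assuming AsyncHTTPHandler forwards the limit to httpx's connection pool (the 600.0 timeout is an arbitrary example value):

    import httpx

    # With max_connections=1, concurrent requests through this client
    # queue for the one connection instead of opening new ones.
    client = httpx.AsyncClient(
        timeout=httpx.Timeout(timeout=600.0),
        limits=httpx.Limits(max_connections=1, max_keepalive_connections=1),
    )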