Merge pull request #5004 from BerriAI/litellm_codestral_fim_support

feat(vertex_ai_partner.py): add vertex ai codestral FIM support
Krish Dholakia 2024-08-01 21:24:12 -07:00 committed by GitHub
commit d8778380d8
7 changed files with 199 additions and 40 deletions

docs/my-website/docs/providers/vertex.md

@ -833,7 +833,11 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
| Model Name | Function Call |
|------------------|--------------------------------------|
| meta/llama3-405b-instruct-maas | `completion('vertex_ai/meta/llama3-405b-instruct-maas', messages)` |
| mistral-large@latest | `completion('vertex_ai/mistral-large@latest', messages)` |
| mistral-large@2407 | `completion('vertex_ai/mistral-large@2407', messages)` |
| mistral-nemo@latest | `completion('vertex_ai/mistral-nemo@latest', messages)` |
| codestral@latest | `completion('vertex_ai/codestral@latest', messages)` |
| codestral@2405 | `completion('vertex_ai/codestral@2405', messages)` |
### Usage
@ -866,12 +870,12 @@ print("\nModel Response", response)
```yaml
model_list:
- model_name: anthropic-mistral
- model_name: vertex-mistral
litellm_params:
model: vertex_ai/mistral-large@2407
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-east-1"
- model_name: anthropic-mistral
- model_name: vertex-mistral
litellm_params:
model: vertex_ai/mistral-large@2407
vertex_ai_project: "my-test-project"
@ -893,7 +897,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "anthropic-mistral", # 👈 the 'model_name' in config
"model": "vertex-mistral", # 👈 the 'model_name' in config
"messages": [
{
"role": "user",
@ -907,6 +911,94 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</Tabs>
### Usage - Codestral FIM
Call Codestral on VertexAI via the OpenAI [`/v1/completions`](https://platform.openai.com/docs/api-reference/completions/create) endpoint for FIM (fill-in-the-middle) tasks.
Note: You can also call Codestral via `/chat/completions`.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import text_completion
import os
# os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""
# OR run `!gcloud auth print-access-token` in your terminal
model = "codestral@2405"
vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"]
vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"]
response = text_completion(
model="vertex_ai/" + model,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True", # optional
temperature=0, # optional
top_p=1, # optional
max_tokens=10, # optional
min_tokens=10, # optional
seed=10, # optional
stop=["return"], # optional
)
print("\nModel Response", response)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
**1. Add to config**
```yaml
model_list:
- model_name: vertex-codestral
litellm_params:
model: vertex_ai/codestral@2405
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-east-1"
- model_name: vertex-codestral
litellm_params:
model: vertex_ai/codestral@2405
vertex_ai_project: "my-test-project"
vertex_ai_location: "us-west-1"
```
**2. Start proxy**
```bash
litellm --config /path/to/config.yaml
# RUNNING at http://0.0.0.0:4000
```
**3. Test it!**
```bash
curl -X POST 'http://0.0.0.0:4000/completions' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "vertex-codestral", # 👈 the 'model_name' in config
"prompt": "def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
"suffix":"return True", # optional
"temperature":0, # optional
"top_p":1, # optional
"max_tokens":10, # optional
"min_tokens":10, # optional
"seed":10, # optional
"stop":["return"], # optional
}'
```
</TabItem>
</Tabs>
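The same FIM route can also be exercised from Python by pointing the OpenAI SDK at the proxy. A minimal sketch, assuming the proxy from step 2 is running on `http://0.0.0.0:4000` and `sk-1234` is a valid proxy key:

```python
import openai

# Point the OpenAI client at the LiteLLM proxy (assumed running locally).
client = openai.OpenAI(
    api_key="sk-1234",                # LiteLLM proxy key
    base_url="http://0.0.0.0:4000",   # LiteLLM proxy base URL
)

# Same FIM request as the curl example above, sent to /completions.
response = client.completions.create(
    model="vertex-codestral",         # the 'model_name' from the config
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",             # optional
    temperature=0,                    # optional
    max_tokens=10,                    # optional
)

print(response.choices[0].text)
```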
## Model Garden
| Model Name | Function Call |
|------------------|--------------------------------------|

litellm/integrations/langfuse.py

@ -5,6 +5,7 @@ import os
import traceback
from packaging.version import Version
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
@ -43,8 +44,8 @@ class LangFuseLogger:
self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
parameters = {
"public_key": self.public_key,
"secret_key": self.secret_key,
"public_key": "pk-lf-b3db7e8e-c2f6-4fc7-825c-a541a8fbe003",
"secret_key": "sk-lf-b11ef3a8-361c-4445-9652-12318b8596e4",
"host": self.langfuse_host,
"release": self.langfuse_release,
"debug": self.langfuse_debug,
@ -331,7 +332,7 @@ class LangFuseLogger:
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
except Exception:
new_metadata = {}
for key, value in metadata.items():
if (
@ -342,6 +343,8 @@ class LangFuseLogger:
or isinstance(value, float)
):
new_metadata[key] = copy.deepcopy(value)
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump()
metadata = new_metadata
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
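In effect, Pydantic objects placed in request metadata are now serialized with `model_dump()` instead of being dropped. A small standalone sketch of that behaviour (the `UserInfo` model is hypothetical, for illustration only):

```python
import copy

from pydantic import BaseModel


class UserInfo(BaseModel):  # hypothetical metadata value, illustration only
    user_id: str
    team: str


metadata = {"trace_name": "codestral-fim", "user": UserInfo(user_id="u1", team="ml")}

# Mirrors the sanitisation loop above: primitives are deep-copied,
# BaseModel values are converted to plain dicts via model_dump().
new_metadata = {}
for key, value in metadata.items():
    if isinstance(value, (str, int, float)):
        new_metadata[key] = copy.deepcopy(value)
    elif isinstance(value, BaseModel):
        new_metadata[key] = value.model_dump()

print(new_metadata)
# {'trace_name': 'codestral-fim', 'user': {'user_id': 'u1', 'team': 'ml'}}
```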

litellm/llms/text_completion_codestral.py

@ -1,28 +1,33 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/
from functools import partial
import os, types
import traceback
import copy
import json
from enum import Enum
import requests, copy # type: ignore
import os
import time
from typing import Callable, Optional, List, Literal, Union
import traceback
import types
from enum import Enum
from functools import partial
from typing import Callable, List, Literal, Optional, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import (
TextCompletionResponse,
Usage,
Choices,
CustomStreamWrapper,
Message,
Choices,
TextCompletionResponse,
Usage,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.types.llms.databricks import GenericStreamingChunk
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .prompt_templates.factory import custom_prompt, prompt_factory
class TextCompletionCodestralError(Exception):
@ -329,7 +334,12 @@ class CodestralTextCompletion(BaseLLM):
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
if optional_params.pop("custom_endpoint", None) is True:
completion_url = api_base
else:
completion_url = (
api_base or "https://codestral.mistral.ai/v1/fim/completions"
)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
@ -426,6 +436,7 @@ class CodestralTextCompletion(BaseLLM):
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
@ -464,8 +475,11 @@ class CodestralTextCompletion(BaseLLM):
headers={},
) -> TextCompletionResponse:
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
)
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
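The `custom_endpoint` flag is what lets the Vertex AI partner path reuse this handler: when set, the caller-supplied `api_base` is used as-is instead of the hosted Codestral FIM URL. A condensed sketch of that URL selection (the Vertex-style URL below is illustrative, not taken from the library):

```python
def resolve_fim_url(api_base, optional_params):
    # Mirrors the new branch above: a caller that sets custom_endpoint=True
    # (e.g. the Vertex AI partner path) keeps its own endpoint.
    if optional_params.pop("custom_endpoint", None) is True:
        return api_base
    return api_base or "https://codestral.mistral.ai/v1/fim/completions"


# Direct Codestral API call: falls back to the hosted FIM endpoint.
print(resolve_fim_url(None, {}))

# Vertex AI partner call: keeps the caller-supplied endpoint (illustrative URL).
print(resolve_fim_url("https://example-vertex-rawpredict-endpoint", {"custom_endpoint": True}))
```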

litellm/llms/vertex_ai_partner.py

@ -140,10 +140,10 @@ class VertexAIPartnerModels(BaseLLM):
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
@ -154,6 +154,7 @@ class VertexAIPartnerModels(BaseLLM):
from litellm.llms.databricks import DatabricksChatCompletion
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.text_completion_codestral import CodestralTextCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
@ -178,12 +179,7 @@ class VertexAIPartnerModels(BaseLLM):
)
openai_like_chat_completions = DatabricksChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
# for k, v in config.items():
# if k not in optional_params:
# optional_params[k] = v
codestral_fim_completions = CodestralTextCompletion()
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
@ -206,6 +202,28 @@ class VertexAIPartnerModels(BaseLLM):
model = model.split("@")[0]
if "codestral" in model and litellm_params.get("text_completion") is True:
optional_params["model"] = model
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
return codestral_fim_completions.completion(
model=model,
messages=messages,
api_base=api_base,
api_key=access_token,
custom_prompt_dict=custom_prompt_dict,
model_response=text_completion_model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
)
return openai_like_chat_completions.completion(
model=model,
messages=messages,
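Condensed, the new dispatch inside `VertexAIPartnerModels` is: codestral models go to the FIM handler only when the call came in through the text-completion route, everything else stays on the OpenAI-compatible chat path. A small sketch (handler names are returned as strings purely for illustration):

```python
def pick_partner_handler(model: str, litellm_params: dict) -> str:
    # codestral + text_completion flag -> CodestralTextCompletion (FIM)
    if "codestral" in model and litellm_params.get("text_completion") is True:
        return "codestral_fim"
    # everything else -> OpenAI-compatible chat completions handler
    return "openai_like_chat"


assert pick_partner_handler("codestral", {"text_completion": True}) == "codestral_fim"
assert pick_partner_handler("codestral", {}) == "openai_like_chat"  # /chat/completions route
assert pick_partner_handler("mistral-large", {"text_completion": True}) == "openai_like_chat"
```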

litellm/main.py

@ -986,6 +986,7 @@ def completion(
output_cost_per_second=output_cost_per_second,
output_cost_per_token=output_cost_per_token,
cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"),
)
logging.update_environment_variables(
model=model,
@ -2085,7 +2086,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,

litellm/tests/test_text_completion.py

@ -4104,9 +4104,19 @@ async def test_async_text_completion_chat_model_stream():
# asyncio.run(test_async_text_completion_chat_model_stream())
@pytest.mark.parametrize(
"model", ["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"] #
)
@pytest.mark.asyncio
async def test_completion_codestral_fim_api():
async def test_completion_codestral_fim_api(model):
try:
if model == "vertex_ai/codestral@2405":
from litellm.tests.test_amazing_vertex_completion import (
load_vertex_ai_credentials,
)
load_vertex_ai_credentials()
litellm.set_verbose = True
import logging
@ -4114,7 +4124,7 @@ async def test_completion_codestral_fim_api():
verbose_logger.setLevel(level=logging.DEBUG)
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
model=model,
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True",
temperature=0,
@ -4137,9 +4147,19 @@ async def test_completion_codestral_fim_api():
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize(
"model",
["vertex_ai/codestral@2405", "text-completion-codestral/codestral-2405"],
)
@pytest.mark.asyncio
async def test_completion_codestral_fim_api_stream():
async def test_completion_codestral_fim_api_stream(model):
try:
if model == "vertex_ai/codestral@2405":
from litellm.tests.test_amazing_vertex_completion import (
load_vertex_ai_credentials,
)
load_vertex_ai_credentials()
import logging
from litellm._logging import verbose_logger
@ -4148,7 +4168,7 @@ async def test_completion_codestral_fim_api_stream():
# verbose_logger.setLevel(level=logging.DEBUG)
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
model=model,
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True",
temperature=0,
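To run just these two parametrized FIM tests locally, something like the following should work (the test module path is an assumption based on the test names; adjust if the file lives elsewhere):

```python
import subprocess

# Assumed test module path; both FIM tests match -k "codestral_fim".
subprocess.run(
    [
        "python", "-m", "pytest",
        "litellm/tests/test_text_completion.py",
        "-k", "codestral_fim",
        "-x", "-s",
    ],
    check=False,
)
```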

litellm/utils.py

@ -2258,6 +2258,7 @@ def get_litellm_params(
output_cost_per_token=None,
output_cost_per_second=None,
cooldown_time=None,
text_completion=None,
):
litellm_params = {
"acompletion": acompletion,
@ -2281,6 +2282,7 @@ def get_litellm_params(
"output_cost_per_token": output_cost_per_token,
"output_cost_per_second": output_cost_per_second,
"cooldown_time": cooldown_time,
"text_completion": text_completion,
}
return litellm_params
@ -3127,10 +3129,15 @@ def get_optional_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
if "codestral" in model:
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params
)
else:
optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
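The parameter-mapping change follows the same split: within the `vertex_ai` partner branch, codestral models get the Mistral text-completion (FIM) mapping while other Mistral models keep the chat mapping. A condensed sketch (the helper name is made up; both config classes are the ones used in the hunk above):

```python
import litellm


def map_vertex_mistral_params(model: str, non_default_params: dict, optional_params: dict) -> dict:
    if "codestral" in model:
        # FIM models: text-completion style parameter mapping
        return litellm.MistralTextCompletionConfig().map_openai_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
    # other vertex mistral models: chat parameter mapping
    return litellm.MistralConfig().map_openai_params(
        non_default_params=non_default_params, optional_params=optional_params
    )
```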
@ -4239,6 +4246,10 @@ def get_supported_openai_params(
return litellm.VertexAILlama3Config().get_supported_openai_params()
if model.startswith("mistral"):
return litellm.MistralConfig().get_supported_openai_params()
if model.startswith("codestral"):
return (
litellm.MistralTextCompletionConfig().get_supported_openai_params()
)
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
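With the last hunk applied, querying supported params for a vertex_ai codestral model is expected to return the FIM set rather than the chat set. A hedged usage sketch:

```python
import litellm

# Chat-style mapping for regular vertex mistral models.
print(litellm.get_supported_openai_params(
    model="mistral-large@2407", custom_llm_provider="vertex_ai"
))

# Expected to come from MistralTextCompletionConfig after this change.
print(litellm.get_supported_openai_params(
    model="codestral@2405", custom_llm_provider="vertex_ai"
))
```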