forked from phoenix/litellm-mirror
add test for ft endpoints on azure
commit 287b09cff6
parent c8dfc95e90
4 changed files with 74 additions and 5 deletions
@@ -40,11 +40,20 @@ fine_tuning_config = None
 
 def set_fine_tuning_config(config):
     global fine_tuning_config
+    if not isinstance(config, list):
+        raise ValueError("invalid fine_tuning config, expected a list")
+
+    for element in config:
+        if isinstance(element, dict):
+            for key, value in element.items():
+                if isinstance(value, str) and value.startswith("os.environ/"):
+                    element[key] = litellm.get_secret(value)
+
     fine_tuning_config = config
 
 
 # Function to search for a specific custom_llm_provider and return its configuration
-def get_provider_config(
+def get_fine_tuning_provider_config(
     custom_llm_provider: str,
 ):
     global fine_tuning_config
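
Note: a standalone sketch of the config shape set_fine_tuning_config expects and of the "os.environ/" resolution convention. The provider entries below are illustrative (not taken from this commit), and a plain os.environ lookup stands in for litellm.get_secret:

import os

# Illustrative config list; any string value of the form "os.environ/<VAR>"
# is replaced with the value of that environment variable at load time.
config = [
    {
        "custom_llm_provider": "azure",
        "api_key": "os.environ/AZURE_API_KEY",
        "api_base": "os.environ/AZURE_API_BASE",
    }
]

for element in config:
    if isinstance(element, dict):
        for key, value in element.items():
            if isinstance(value, str) and value.startswith("os.environ/"):
                # the committed code calls litellm.get_secret(value) here
                element[key] = os.environ.get(value.split("os.environ/", 1)[1])
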
@@ -122,7 +131,7 @@ async def create_fine_tuning_job(
     )
 
     # get configs for custom_llm_provider
-    llm_provider_config = get_provider_config(
+    llm_provider_config = get_fine_tuning_provider_config(
        custom_llm_provider=fine_tuning_request.custom_llm_provider,
     )
 
@@ -34,6 +34,34 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
 
 router = APIRouter()
 
+files_config = None
+
+
+def set_files_config(config):
+    global files_config
+    if not isinstance(config, list):
+        raise ValueError("invalid files config, expected a list")
+
+    for element in config:
+        if isinstance(element, dict):
+            for key, value in element.items():
+                if isinstance(value, str) and value.startswith("os.environ/"):
+                    element[key] = litellm.get_secret(value)
+
+    files_config = config
+
+
+def get_files_provider_config(
+    custom_llm_provider: str,
+):
+    global files_config
+    if files_config is None:
+        raise ValueError("files_config is not set, set it on your config.yaml file.")
+    for setting in files_config:
+        if setting.get("custom_llm_provider") == custom_llm_provider:
+            return setting
+    return None
+
+
 @router.post(
     "/v1/files",
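
Note: the lookup semantics of get_files_provider_config in isolation — the first entry whose custom_llm_provider matches wins, and anything else falls through to None. The settings below are invented for illustration:

files_config = [
    {"custom_llm_provider": "azure", "api_base": "https://example-resource.openai.azure.com"},
    {"custom_llm_provider": "openai"},
]

def get_files_provider_config(custom_llm_provider: str):
    # scan the configured providers in order; first match wins
    for setting in files_config:
        if setting.get("custom_llm_provider") == custom_llm_provider:
            return setting
    return None

assert get_files_provider_config("azure")["api_base"].startswith("https://")
assert get_files_provider_config("bedrock") is None  # no match -> None
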
@@ -49,6 +77,7 @@ async def create_file(
     request: Request,
     fastapi_response: Response,
     purpose: str = Form(...),
+    custom_llm_provider: str = Form(...),
     file: UploadFile = File(...),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
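
Note: because custom_llm_provider is declared with Form(...), clients send it as an extra multipart form field. A hypothetical raw-HTTP sketch using requests, with the same placeholder URL and key as the test at the end of this commit:

import requests

# the new "custom_llm_provider" key rides alongside "purpose" in the form data
response = requests.post(
    "http://0.0.0.0:4000/v1/files",
    headers={"Authorization": "Bearer sk-1234"},
    data={"purpose": "fine-tune", "custom_llm_provider": "azure"},
    files={"file": open("openai_batch_completions.jsonl", "rb")},
)
print(response.json())
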
@@ -100,11 +129,17 @@ async def create_file(
 
     _create_file_request = CreateFileRequest(file=file_data, **data)
 
-    # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
-    response = await litellm.acreate_file(
-        custom_llm_provider="openai", **_create_file_request
+    # get configs for custom_llm_provider
+    llm_provider_config = get_files_provider_config(
+        custom_llm_provider=custom_llm_provider
     )
+
+    # add llm_provider_config to data
+    _create_file_request.update(llm_provider_config)
+
+    # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+    response = await litellm.acreate_file(**_create_file_request)
 
     ### ALERTING ###
     asyncio.create_task(
         proxy_logging_obj.update_request_status(
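
Note: the merge is plain dict.update semantics, so keys from the provider config overwrite same-named keys already in the request; and if get_files_provider_config returns None for an unconfigured provider, update(None) raises TypeError, so a missing config entry surfaces as a server error. A toy illustration with invented values:

# stand-ins for CreateFileRequest and the looked-up provider settings
_create_file_request = {"purpose": "fine-tune", "file": b"<bytes>"}
llm_provider_config = {"custom_llm_provider": "azure", "api_base": "https://example-resource.openai.azure.com"}

# provider settings are folded in before the litellm.acreate_file(**...) call
_create_file_request.update(llm_provider_config)

assert _create_file_request["custom_llm_provider"] == "azure"
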
tests/openai_batch_completions.jsonl (new file, +2)
@@ -0,0 +1,2 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
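
Note: each line of the .jsonl must be a self-contained JSON object in the OpenAI batch request format. A quick sanity check, not part of the commit:

import json

with open("tests/openai_batch_completions.jsonl") as f:
    for line in f:
        request = json.loads(line)  # every line parses on its own
        assert {"custom_id", "method", "url", "body"} <= request.keys()
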
tests/test_openai_fine_tuning.py (new file, +23)
@@ -0,0 +1,23 @@
+from openai import AsyncOpenAI
+import os
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_openai_fine_tuning():
+    """
+    [PROD Test] Ensures a fine-tuning file can be created via the proxy with custom_llm_provider="azure"
+    """
+    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+    file_name = "openai_batch_completions.jsonl"
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(_current_dir, file_name)
+
+    response = await client.files.create(
+        extra_body={"custom_llm_provider": "azure"},
+        file=open(file_path, "rb"),
+        purpose="fine-tune",
+    )
+
+    print("response from files.create: {}".format(response))
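
Note: this test assumes a LiteLLM proxy is already listening on http://0.0.0.0:4000 with master key sk-1234 and an azure entry in its files settings; it exercises the new /v1/files form field end to end against that live server rather than mocking the provider (the OpenAI SDK forwards extra_body keys as additional form fields on multipart requests).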