mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): working /v1/messages
with config.yaml
Adds async router support for adapter_completion call
This commit is contained in:
parent
2f8dbbeb97
commit
31829855c0
7 changed files with 362 additions and 14 deletions
|
@ -423,7 +423,7 @@ class AnthropicConfig:
|
|||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=tool_call.function.arguments,
|
||||
input=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
)
|
||||
elif choice.message.content is not None:
|
||||
|
|
|
@ -3948,6 +3948,36 @@ def text_completion(
|
|||
###### Adapter Completion ################
|
||||
|
||||
|
||||
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
|
||||
"""
|
||||
Implemented to handle async calls for adapter_completion()
|
||||
"""
|
||||
try:
|
||||
translation_obj: Optional[CustomLogger] = None
|
||||
for item in litellm.adapters:
|
||||
if item["id"] == adapter_id:
|
||||
translation_obj = item["adapter"]
|
||||
|
||||
if translation_obj is None:
|
||||
raise ValueError(
|
||||
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
|
||||
adapter_id, litellm.adapters
|
||||
)
|
||||
)
|
||||
|
||||
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
|
||||
|
||||
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
|
||||
|
||||
translated_response = translation_obj.translate_completion_output_params(
|
||||
response=response
|
||||
)
|
||||
|
||||
return translated_response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
|
||||
translation_obj: Optional[CustomLogger] = None
|
||||
for item in litellm.adapters:
|
||||
|
|
|
@ -2,7 +2,9 @@ model_list:
|
|||
- model_name: "*"
|
||||
litellm_params:
|
||||
model: "openai/*"
|
||||
|
||||
- model_name: claude-3-5-sonnet-20240620
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
|
||||
general_settings:
|
||||
|
|
|
@ -5045,23 +5045,187 @@ async def moderations(
|
|||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AnthropicResponse,
|
||||
)
|
||||
async def anthropic_response(data: AnthropicMessagesRequest):
|
||||
async def anthropic_response(
|
||||
anthropic_data: AnthropicMessagesRequest,
|
||||
fastapi_response: Response,
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
from litellm import adapter_completion
|
||||
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
||||
|
||||
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
|
||||
|
||||
response: Optional[BaseModel] = adapter_completion(adapter_id="anthropic", **data)
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
data: dict = {**anthropic_data, "adapter_id": "anthropic"}
|
||||
try:
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data["model"] # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
if response is None:
|
||||
raise Exception("Response is None.")
|
||||
elif not isinstance(response, AnthropicResponse):
|
||||
raise Exception(
|
||||
"Invalid model response={}. Not in 'AnthropicResponse' format".format(
|
||||
response
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data, # type: ignore
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if data["model"] in litellm.model_alias_map:
|
||||
data["model"] = litellm.model_alias_map[data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
|
||||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(
|
||||
llm_router.aadapter_completion(**data, specific_deployment=True)
|
||||
)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and llm_router.default_deployment is not None
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "completion: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
|
||||
# Await the llm_response task
|
||||
response = await llm_response
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
|
||||
return response
|
||||
except RejectedRequestError as e:
|
||||
_data = e.request_data
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=_data,
|
||||
)
|
||||
if _data.get("stream", None) is not None and _data["stream"] == True:
|
||||
_chat_response = litellm.ModelResponse()
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
_chat_response.usage = _usage # type: ignore
|
||||
_chat_response.choices[0].message.content = e.message # type: ignore
|
||||
_iterator = litellm.utils.ModelResponseIterator(
|
||||
model_response=_chat_response, convert_to_delta=True
|
||||
)
|
||||
_streaming_response = litellm.TextCompletionStreamWrapper(
|
||||
completion_stream=_iterator,
|
||||
model=_data.get("model", ""),
|
||||
)
|
||||
|
||||
selected_data_generator = select_data_generator(
|
||||
response=_streaming_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers={},
|
||||
)
|
||||
else:
|
||||
_response = litellm.TextCompletionResponse()
|
||||
_response.choices[0].text = e.message
|
||||
return _response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
#### DEV UTILS ####
|
||||
|
|
|
@ -1764,6 +1764,125 @@ class Router:
|
|||
self.fail_calls[model] += 1
|
||||
raise e
|
||||
|
||||
async def aadapter_completion(
|
||||
self,
|
||||
adapter_id: str,
|
||||
model: str,
|
||||
is_retry: Optional[bool] = False,
|
||||
is_fallback: Optional[bool] = False,
|
||||
is_async: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["adapter_id"] = adapter_id
|
||||
kwargs["original_function"] = self._aadapter_completion
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
timeout = kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
response = await self.async_function_with_fallbacks(**kwargs)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
send_llm_exception_alert(
|
||||
litellm_router_instance=self,
|
||||
request_kwargs=kwargs,
|
||||
error_traceback_str=traceback.format_exc(),
|
||||
original_exception=e,
|
||||
)
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _aadapter_completion(self, adapter_id: str, model: str, **kwargs):
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
|
||||
)
|
||||
deployment = await self.async_get_available_deployment(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "default text"}],
|
||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||
)
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{
|
||||
"deployment": deployment["litellm_params"]["model"],
|
||||
"model_info": deployment.get("model_info", {}),
|
||||
"api_base": deployment.get("litellm_params", {}).get("api_base"),
|
||||
}
|
||||
)
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if (
|
||||
k not in kwargs
|
||||
): # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
potential_model_client = self._get_client(
|
||||
deployment=deployment, kwargs=kwargs, client_type="async"
|
||||
)
|
||||
# check if provided keys == client keys #
|
||||
dynamic_api_key = kwargs.get("api_key", None)
|
||||
if (
|
||||
dynamic_api_key is not None
|
||||
and potential_model_client is not None
|
||||
and dynamic_api_key != potential_model_client.api_key
|
||||
):
|
||||
model_client = None
|
||||
else:
|
||||
model_client = potential_model_client
|
||||
self.total_calls[model_name] += 1
|
||||
|
||||
response = litellm.aadapter_completion(
|
||||
**{
|
||||
**data,
|
||||
"adapter_id": adapter_id,
|
||||
"caching": self.cache_responses,
|
||||
"client": model_client,
|
||||
"timeout": self.timeout,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
rpm_semaphore = self._get_client(
|
||||
deployment=deployment,
|
||||
kwargs=kwargs,
|
||||
client_type="max_parallel_requests",
|
||||
)
|
||||
|
||||
if rpm_semaphore is not None and isinstance(
|
||||
rpm_semaphore, asyncio.Semaphore
|
||||
):
|
||||
async with rpm_semaphore:
|
||||
"""
|
||||
- Check rpm limits before making the call
|
||||
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
|
||||
"""
|
||||
await self.async_routing_strategy_pre_call_checks(
|
||||
deployment=deployment
|
||||
)
|
||||
response = await response # type: ignore
|
||||
else:
|
||||
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
|
||||
response = await response # type: ignore
|
||||
|
||||
self.success_calls[model_name] += 1
|
||||
verbose_router_logger.info(
|
||||
f"litellm.aadapter_completion(model={model_name})\033[32m 200 OK\033[0m"
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_router_logger.info(
|
||||
f"litellm.aadapter_completion(model={model})\033[31m Exception {str(e)}\033[0m"
|
||||
)
|
||||
if model is not None:
|
||||
self.fail_calls[model] += 1
|
||||
raise e
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -20,7 +20,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import AnthropicConfig, adapter_completion
|
||||
from litellm import AnthropicConfig, Router, adapter_completion
|
||||
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
||||
from litellm.types.llms.anthropic import AnthropicResponse
|
||||
|
||||
|
@ -67,4 +67,37 @@ def test_anthropic_completion_e2e():
|
|||
|
||||
assert isinstance(response, AnthropicResponse)
|
||||
|
||||
assert False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_router_completion_e2e():
|
||||
litellm.set_verbose = True
|
||||
|
||||
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "claude-3-5-sonnet-20240620",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"mock_response": "hi this is macintosh.",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
||||
response = await router.aadapter_completion(
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
messages=messages,
|
||||
adapter_id="anthropic",
|
||||
mock_response="This is a fake call",
|
||||
)
|
||||
|
||||
print("Response: {}".format(response))
|
||||
|
||||
assert response is not None
|
||||
|
||||
assert isinstance(response, AnthropicResponse)
|
||||
|
||||
assert response.model == "gpt-3.5-turbo"
|
||||
|
|
|
@ -234,7 +234,7 @@ class AnthropicResponseContentBlockToolUse(BaseModel):
|
|||
type: Literal["tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: str
|
||||
input: dict
|
||||
|
||||
|
||||
class AnthropicResponseUsageBlock(BaseModel):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue