forked from phoenix/litellm-mirror

fix(sagemaker.py): support 'model_id' param for sagemaker

Allow passing the inference component param to SageMaker in the same format as we handle it for Bedrock.

parent 90c007fc69
commit d547944556

4 changed files with 47 additions and 11 deletions
````diff
@@ -20,7 +20,28 @@ os.environ["AWS_SECRET_ACCESS_KEY"] = ""
 os.environ["AWS_REGION_NAME"] = ""
 
 response = completion(
-            model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
+            model="sagemaker/<your-endpoint-name>",
             messages=[{ "content": "Hello, how are you?","role": "user"}],
             temperature=0.2,
             max_tokens=80
+        )
+```
+
+### Passing Inference Component Name
+
+If you have multiple models on an endpoint, you'll need to specify the individual model names; do this via `model_id`.
+
+```python
+import os
+from litellm import completion
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+response = completion(
+            model="sagemaker/<your-endpoint-name>",
+            model_id="<your-model-name>",
+            messages=[{ "content": "Hello, how are you?","role": "user"}],
+            temperature=0.2,
+            max_tokens=80
````
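Under the hood, LiteLLM's `model_id` maps to the `InferenceComponentName` argument of the SageMaker runtime's `invoke_endpoint` call (see the sagemaker.py hunks below). A minimal boto3 sketch of the equivalent raw call, with hypothetical endpoint and component names:

```python
import json
import boto3

# Hypothetical names; substitute your own endpoint and inference component.
runtime = boto3.client("sagemaker-runtime", region_name="us-west-2")

response = runtime.invoke_endpoint(
    EndpointName="my-multi-model-endpoint",           # the shared endpoint
    InferenceComponentName="my-inference-component",  # selects one model on it
    ContentType="application/json",
    Body=json.dumps({"inputs": "Hello, how are you?"}),
    CustomAttributes="accept_eula=true",
)
print(response["Body"].read().decode("utf8"))
```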
```diff
@@ -166,6 +166,7 @@ def completion(
     aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
     aws_access_key_id = optional_params.pop("aws_access_key_id", None)
     aws_region_name = optional_params.pop("aws_region_name", None)
+    model_id = optional_params.pop("model_id", None)
 
     if aws_access_key_id != None:
         # uses auth params passed to completion
```
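The new parameter follows the existing pattern in this function: provider-specific kwargs are popped off `optional_params` so they configure the request without leaking into the JSON payload sent to the endpoint. A toy illustration of why `pop` (rather than `get`) matters here:

```python
optional_params = {"model_id": "my-component", "temperature": 0.2, "max_tokens": 80}

# pop() removes the key, so it cannot end up in the request body
model_id = optional_params.pop("model_id", None)

assert model_id == "my-component"
assert "model_id" not in optional_params  # only real inference params remain
```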
```diff
@@ -288,12 +289,21 @@ def completion(
         )
     ## COMPLETION CALL
     try:
-        response = client.invoke_endpoint(
-            EndpointName=model,
-            ContentType="application/json",
-            Body=data,
-            CustomAttributes="accept_eula=true",
-        )
+        if model_id is not None:
+            response = client.invoke_endpoint(
+                EndpointName=model,
+                InferenceComponentName=model_id,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
+        else:
+            response = client.invoke_endpoint(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
     except Exception as e:
         status_code = (
             getattr(e, "response", {})
```
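The two branches above differ only in the `InferenceComponentName` argument. An equivalent, slightly drier formulation (a sketch, not what the commit ships) would build the kwargs once and add the header conditionally:

```python
invoke_kwargs = dict(
    EndpointName=model,
    ContentType="application/json",
    Body=data,
    CustomAttributes="accept_eula=true",
)
if model_id is not None:
    # only multi-model (inference component) endpoints need this header
    invoke_kwargs["InferenceComponentName"] = model_id
response = client.invoke_endpoint(**invoke_kwargs)
```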
```diff
@@ -303,6 +313,8 @@ def completion(
         error_message = (
             getattr(e, "response", {}).get("Error", {}).get("Message", str(e))
         )
+        if "Inference Component Name header is required" in error_message:
+            error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
         raise SagemakerError(status_code=status_code, message=error_message)
 
     response = response["Body"].read().decode("utf8")
```
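The `getattr(e, "response", {})` chains work because botocore's `ClientError` carries the parsed error document on `e.response`. A sketch of the shape being probed, reusing `client`, `model`, and `data` from the hunk above (key names per botocore's documented error structure):

```python
import botocore.exceptions

def invoke_with_readable_errors(client, model, data):
    try:
        return client.invoke_endpoint(
            EndpointName=model, ContentType="application/json", Body=data
        )
    except botocore.exceptions.ClientError as e:
        # e.response looks roughly like:
        # {"Error": {"Code": "ValidationError", "Message": "..."},
        #  "ResponseMetadata": {"HTTPStatusCode": 400, ...}}
        message = e.response.get("Error", {}).get("Message", str(e))
        status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 500)
        raise RuntimeError(f"{status}: {message}") from e
```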
```diff
@@ -671,7 +671,7 @@ def completion(
         elif (
             input_cost_per_second is not None
         ):  # time based pricing just needs cost in place
-            output_cost_per_second = output_cost_per_second or 0.0
+            output_cost_per_second = output_cost_per_second
             litellm.register_model(
                 {
                     f"{custom_llm_provider}/{model}": {
```
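For context, this branch of main.py registers time-based pricing when only `input_cost_per_second` is supplied, which is how the SageMaker test further down meters a dedicated endpoint by runtime. A sketch of the caller-side registration this enables (the rate is hypothetical; the key names follow litellm's cost map):

```python
import litellm

# hypothetical per-second pricing for a dedicated sagemaker endpoint
litellm.register_model({
    "sagemaker/<your-endpoint-name>": {
        "input_cost_per_second": 0.000420,
        "output_cost_per_second": 0.000420,
    }
})
```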
```diff
@@ -2796,7 +2796,7 @@ def embedding(
             or get_secret("OLLAMA_API_BASE")
             or "http://localhost:11434"
         )
-        if isinstance(input ,str):
+        if isinstance(input, str):
             input = [input]
         if not all(isinstance(item, str) for item in input):
             raise litellm.BadRequestError(
```
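The change itself is cosmetic (`input ,str` becomes `input, str`), but the surrounding pattern, coercing a bare string into a one-element batch and then validating every element, is worth seeing in isolation. A self-contained sketch using a plain `ValueError` in place of `litellm.BadRequestError`:

```python
def normalize_input(input):
    # a bare string becomes a single-item batch
    if isinstance(input, str):
        input = [input]
    if not all(isinstance(item, str) for item in input):
        raise ValueError("input must be a str or a list of str")
    return input

assert normalize_input("hello") == ["hello"]
assert normalize_input(["a", "b"]) == ["a", "b"]
```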
```diff
@@ -1720,16 +1720,19 @@ def test_customprompt_together_ai():
 # test_customprompt_together_ai()
 
 
 @pytest.mark.skip(reason="AWS Suspended Account")
 def test_completion_sagemaker():
     try:
         litellm.set_verbose = True
         print("testing sagemaker")
         response = completion(
-            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233",
+            model_id="huggingface-llm-mistral-7b-instruct-20240329-150233",
             messages=messages,
             temperature=0.2,
             max_tokens=80,
+            aws_region_name=os.getenv("AWS_REGION_NAME_2"),
+            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"),
+            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"),
             input_cost_per_second=0.000420,
         )
         # Add any assertions here to check the response
```
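The hunk ends at the placeholder comment about assertions. A sketch of checks one might add there, assuming litellm's OpenAI-compatible response shape (`response.choices[0].message.content`):

```python
# Hypothetical follow-up assertions for test_completion_sagemaker.
assert response is not None
content = response.choices[0].message.content
assert isinstance(content, str) and len(content) > 0
```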