Giving contextual messages to a SageMaker instance in Python

#14
by bperin42 - opened

Hi, I've successfully deployed this model to SageMaker and have a working endpoint. I'm trying to figure out how to send context messages in Python. The code below at least gets a response, but I don't think the endpoint is receiving and parsing the messages correctly: the response echoes all the messages plus the generated continuation, like

{
    "response": "[{\"role\": \"system\", \"content\": \"You are Amastay, an AI rental concierge.\"}, {\"role\": \"user\", \"content\": \"My name is Brian. I am the user interacting with you.\"}, {\"role\": \"user\", \"content\": \"whats your name\"}] 126.45221 plays in the background as the lights dim slightly, setting a moody ambiance. My user name is Brian, I respond to the listening soundscape as we commence our interaction, \"Hello? You've been sweeping nicely up here in the air, Amastay. What do I do first?\" Amastay: \"Ah, Brian, delighted to meet you. As your AI rental concierge, I've got everything taken care of. Let's get started, shall we"
}
import os
import json
import boto3
from flask import Blueprint, request, jsonify

# Load environment variables (Access Key, Secret Key, SageMaker endpoint)
SAGEMAKER_ACCESS_KEY = os.getenv("SAGEMAKER_ACCESS_KEY")
SAGEMAKER_SECRET_ACCESS_KEY = os.getenv("SAGEMAKER_SECRET_ACCESS_KEY")
SAGEMAKER_REGION = os.getenv("SAGEMAKER_REGION", "us-east-1")  # default region
SAGEMAKER_ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT")

# Initialize boto3 SageMaker client
sagemaker_client = boto3.client(
    "sagemaker-runtime",
    aws_access_key_id=SAGEMAKER_ACCESS_KEY,
    aws_secret_access_key=SAGEMAKER_SECRET_ACCESS_KEY,
    region_name=SAGEMAKER_REGION,
)

# Create a Flask Blueprint for the SageMaker query routes
sagemaker_bp = Blueprint("sagemaker_bp", __name__)


@sagemaker_bp.route("/query_model", methods=["POST"])
def query_model():
    """
    Queries the SageMaker model with user input and context.
    Expects a JSON body with 'input' field containing the user message.
    """
    try:
        # Get the input from the POST request
        data = request.json
        user_input = data.get("input")

        if not user_input:
            return jsonify({"error": "No input provided"}), 400

        # Define context with system and user messages
        messages = [
            {"role": "system", "content": "You are Amastay, an AI concierge."},
            {
                "role": "user",
                "content": "My name is Brian. I am the user interacting with you.",
            },
            {"role": "user", "content": user_input},
        ]

        # Prepare the payload for the SageMaker model
        payload = {
            "inputs": json.dumps(messages)
        }  # Ensure messages are sent as a JSON string

        # Send the request to the SageMaker endpoint
        response = sagemaker_client.invoke_endpoint(
            EndpointName=SAGEMAKER_ENDPOINT,
            ContentType="application/json",
            Body=json.dumps(payload),
        )

        # Decode and parse the JSON response body
        response_body = response["Body"].read().decode("utf-8")
        model_response = json.loads(response_body)

        # Handle response format (assumed as list or dict)
        if isinstance(model_response, list):
            result = model_response[0].get("generated_text", "No response received.")
        elif isinstance(model_response, dict):
            result = model_response.get("generated_text", "No response received.")
        else:
            result = "Unexpected response format received."

        # Return the result to the client
        return jsonify({"response": result}), 200

    except Exception as e:
        return jsonify({"error": str(e)}), 500

I would check this out first: https://www.philschmid.de/sagemaker-huggingface-llm

Right now you're passing json.dumps(messages) as the "inputs" string, so the model sees the serialized JSON as plain text and simply continues it, which is why the whole message list comes back in the response.

I would also recommend using the chat template (which isn't used above, but it makes calling LLMs much easier):

from transformers import AutoTokenizer

# Load the tokenizer for the deployed model so its chat template is available
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

messages = [
    {"role": "system", "content": "You are Amastay, an AI concierge."},
    {
        "role": "user",
        "content": "My name is Brian. I am the user interacting with you.",
    },
    {"role": "user", "content": user_input},
]

# Render the messages into a single prompt string; add_generation_prompt=True
# appends the assistant header so the model responds instead of continuing
# the user turn
inputs = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

payload = {"inputs": inputs}
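If you want to control generation, TGI also accepts a "parameters" object next to the prompt; a minimal sketch (the specific values here are just examples):

payload = {
    "inputs": inputs,
    "parameters": {"max_new_tokens": 256, "temperature": 0.7},
}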

Thanks @nbroad!
I was actually able to get the messages format working by setting MESSAGES_API_ENABLED to true. Full script:

from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()

# Hugging Face Hub token from .env
hugging_face_hub_token = os.getenv("HUGGING_FACE_HUB_TOKEN")

# Ensure the token is set properly
if hugging_face_hub_token is None:
    raise ValueError(
        "You must provide a valid Hugging Face Hub token in the .env file."
    )

# Hub Model configuration
hub = {
    "HF_MODEL_ID": "meta-llama/Llama-3.2-3B-Instruct",
    "SM_NUM_GPUS": "1",  # Number of GPUs to use
    "HUGGING_FACE_HUB_TOKEN": hugging_face_hub_token,
    "MESSAGES_API_ENABLED": "true",
}

# Get the image URI for the model
image_uri = get_huggingface_llm_image_uri("huggingface", version="2.2.0")

# Use the correct SageMaker execution role ARN
role_arn = "xxx"

# Create Hugging Face Model Class
huggingface_model = HuggingFaceModel(image_uri=image_uri, env=hub, role=role_arn)

# Update instance type based on GPU requirements
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    container_startup_health_check_timeout=300,
)

# Save the deployed endpoint name for future use
endpoint_name = predictor.endpoint_name
print(f"Deployed endpoint: {endpoint_name}")

# Send request to the deployed model endpoint
response = predictor.predict(
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"},
        ]
    }
)

print(response)
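
# With MESSAGES_API_ENABLED the endpoint speaks an OpenAI-style chat
# completions API, so the response is a dict with a "choices" list. The
# extraction below assumes that shape; verify it against the printed response.
assistant_text = response["choices"][0]["message"]["content"]
print(assistant_text)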

# Clean up the endpoint when not needed (uncomment this line to delete the endpoint)
# predictor.delete_endpoint()
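
For the original Flask route, the same "messages" payload then works with plain boto3 against this endpoint. A minimal sketch reusing the client and variables from the first post (the parsing assumes the Messages API's chat-completion response shape):

# Inside query_model(), build the payload from the user input
payload = {
    "messages": [
        {"role": "system", "content": "You are Amastay, an AI concierge."},
        {"role": "user", "content": user_input},
    ]
}

response = sagemaker_client.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT,
    ContentType="application/json",
    Body=json.dumps(payload),
)

# The Messages API returns a chat completion object
body = json.loads(response["Body"].read().decode("utf-8"))
result = body["choices"][0]["message"]["content"]

return jsonify({"response": result}), 200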
