Stable Diffusion with Amazon Bedrock Python boto3
Amazon Bedrock is a managed AWS service that lets users invoke foundation models through an API. For more information about the service, please refer to the user guide here.
For this article, we’re going to focus on how to invoke the Stable Diffusion model using the Python SDK, boto3. If you want to learn about the Amazon Bedrock SDK in general or how to invoke Claude with it, I have an article here that goes into detail.
What is Stable Diffusion
Stable Diffusion is a text-to-image model first released in 2022. The term “diffusion” originates from diffusion models in machine learning; OpenAI’s DALL-E 2 is another example of a model that leverages diffusion. Stable Diffusion is open source, free, and easy to run! In addition, it has an active community and a plethora of how-to tutorials.
For more information about Stable Diffusion, check out the AWS write up about it.
How to use Stable Diffusion with Amazon Bedrock
For this article, we’ll be leveraging the Python AWS SDK, boto3, to call the Stable Diffusion model.
Install boto3
Let’s begin by installing the latest version of boto3 and Pillow, the maintained fork of the Python Imaging Library (PIL).
pip3 install boto3 Pillow
Look up the model inference parameters
Each model has specific inference parameters that must be supplied to Bedrock.
As of the time of writing (Dec. 22, 2023), the supported inference parameters for Stable Diffusion 1.0 text-to-image are:
{
  "text_prompts": [
    {
      "text": string,
      "weight": float
    }
  ],
  "height": int,
  "width": int,
  "cfg_scale": float,
  "clip_guidance_preset": string,
  "sampler": string,
  "samples": int,
  "seed": int,
  "steps": int,
  "style_preset": string,
  "extras": JSON object
}
The smallest payload we can send is this:
{
  "text_prompts": [{"text": string}]
}
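To make the parameter list more concrete, here is a minimal sketch of a helper that builds a request body and drops any option left unset. The function name `build_body` is hypothetical, the option values are illustrative rather than recommended settings, and the use of a negative `weight` to steer the image away from a prompt is an assumption about how the model treats weights.

```python
import json


def build_body(prompt, negative_prompt=None, **options):
    """Build a Stable Diffusion request body, omitting options left unset."""
    text_prompts = [{"text": prompt, "weight": 1.0}]
    if negative_prompt:
        # Assumption: a negative weight steers the image away from this text.
        text_prompts.append({"text": negative_prompt, "weight": -1.0})

    body = {"text_prompts": text_prompts}
    # Copy only the options the caller actually set.
    body.update({key: value for key, value in options.items() if value is not None})
    return body


body = build_body(
    "A blue bird",
    negative_prompt="blurry, low quality",
    cfg_scale=7,   # illustrative value
    steps=30,      # illustrative value
    seed=42,       # a fixed seed, chosen arbitrarily
    style_preset=None,  # left unset, so it is dropped from the payload
)

# invoke_model expects the body as a JSON string.
payload = json.dumps(body)
print(payload)
```

Keeping unset options out of the payload means the model falls back to its own defaults for them, which is the same behavior as the smallest payload shown above.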
Use the invoke_model API
import base64
import io
import json

import boto3
from PIL import Image

# Bedrock model invocations go through the "bedrock-runtime" client.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Smallest valid payload: a single text prompt.
body = {"text_prompts": [{"text": "A blue bird"}]}

response = client.invoke_model(
    body=json.dumps(body), modelId="stability.stable-diffusion-xl"
)

# The response body is a stream containing JSON.
response_body = json.loads(response["body"].read())
finish_reason = response_body.get("artifacts")[0].get("finishReason")

if finish_reason in ["ERROR", "CONTENT_FILTERED"]:
    raise Exception(f"Image error: {finish_reason}")

# The image is returned as a base64-encoded string; decode it for Pillow.
base64_image = response_body["artifacts"][0]["base64"]
base64_bytes = base64_image.encode("ascii")
image_bytes = base64.b64decode(base64_bytes)
image = Image.open(io.BytesIO(image_bytes))
image.show()
Here is the line-by-line breakdown of the code above:
- Instantiate the boto3.client("bedrock-runtime") client with a region_name of us-east-1.
- json.dumps a dictionary of Stable Diffusion’s required inference parameters.
- Invoke the model by specifying the body (string) and modelId.
- If the finishReason is ERROR or CONTENT_FILTERED, raise an error.
- Convert the response_body to the necessary types for PIL.Image.
- Use image.show() to display a simple GUI with the image.
Output: [image generated for the prompt “A blue bird”]
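For scripts that run without a display, the same steps could be split into small functions that save the result to disk instead of calling image.show(). This is only a sketch: the function names are hypothetical, the `.png` extension assumes the model returns PNG data, and checking every artifact (rather than only the first) assumes a response may contain more than one image.

```python
import base64
import json
from pathlib import Path

MODEL_ID = "stability.stable-diffusion-xl"  # same model ID as the example above
FAILED_REASONS = {"ERROR", "CONTENT_FILTERED"}


def decode_artifacts(response_body):
    """Return the decoded image bytes for every artifact in a response body."""
    images = []
    for index, artifact in enumerate(response_body.get("artifacts", [])):
        reason = artifact.get("finishReason")
        if reason in FAILED_REASONS:
            raise RuntimeError(f"Image {index} failed: {reason}")
        images.append(base64.b64decode(artifact["base64"]))
    if not images:
        raise RuntimeError("Response contained no artifacts")
    return images


def generate_images(prompt, region_name="us-east-1"):
    """Invoke the model once and return a list of raw image bytes."""
    import boto3  # imported here so decode_artifacts can be used without boto3

    client = boto3.client("bedrock-runtime", region_name=region_name)
    response = client.invoke_model(
        body=json.dumps({"text_prompts": [{"text": prompt}]}),
        modelId=MODEL_ID,
    )
    return decode_artifacts(json.loads(response["body"].read()))


def save_images(images, stem="output"):
    """Write each image to <stem>_<n>.png and return the paths."""
    paths = []
    for index, data in enumerate(images):
        path = Path(f"{stem}_{index}.png")  # assumption: the model returns PNG data
        path.write_bytes(data)
        paths.append(path)
    return paths
```

With AWS credentials configured, calling save_images(generate_images("A blue bird")) would write the generated image to the current directory. Keeping decode_artifacts separate from the network call makes the error handling easy to exercise with a hand-built response dictionary.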