Universe Scale Serving of Generative Models

And Why You Might Not Want To Use FastAPI For That Purpose

Posted by Javier Mermet

on April 12, 2023 · 9 mins read

New frontiers

If you're a fan of our blog, you know we're constantly pushing ourselves to be at the frontier of AI/ML tech. In our latest technical post, we introduced Stable Diffusion. Not because of all the hype around it, but because our brand-new Research Squad has been doing ✨ magic ✨ with it!

But building on top of state-of-the-art models is not enough by itself if you plan on using them in production. Much like any other ML project, for this endeavor to thrive we had to build the capabilities for serving, monitoring and scaling the infrastructure that serves as the backbone of the model. This was new ground for us, so we took it one step at a time: we started building from the ground up, learning along the way about the different alternatives for getting the most out of the available hardware when serving GPU-intensive models such as Stable Diffusion.

Goals of our tests

Before doing any benchmarks, we need to know what we'll be optimizing for. When serving a model such as Stable Diffusion, the GPU is most likely to be your hardware bottleneck, whether you go with a cloud provider or on-premise. So, constrained to a given GPU, we want to get the highest throughput possible: the most images per second or the most requests per second. These are similar premises with a nuanced difference: a single request might demand more than one image (a request for four images counts once toward RPS but four times toward images per second). So, let's focus on RPS under the assumption that each request generates just one image.

Slow start: FastAPI

FastAPI, a wrapper on top of Starlette, has gained popularity over the last few years in the Python ecosystem. Many tutorials for Machine Learning practitioners teach you how to serve your models through a RESTful interface using FastAPI, but most lack a complete outlook on the MLOps side of it.

FastAPI falls short in many aspects for our purposes:

  • Performance
    • Ranking: This benchmark is quite popular, and you might notice something unsurprising: Python frameworks do not rank well. So the language of choice is already not optimal.
    • (Micro) Batching: Since FastAPI is not built with the explicit purpose of serving ML models, microbatching is not a built-in feature. You will have to either build it yourself or use a third-party package.
  • Observability: FastAPI is just a web framework. As such, it doesn't bundle features that enable MLOps/DevOps teams to monitor the near real-time status of an app.
  • ASGI only goes as far as your GPU(s): The hardware bottleneck in your system will most likely be the GPU, so you want to make the most out of it. There are some optimizations to be made in the API itself, but you'll have to consider how to use the GPU to its fullest.

Addressing performance

FastAPI optimizations

  1. Don't use StreamingResponse. Seriously. Most online tutorials return the generated images using StreamingResponse, but there are several reasons not to. Using a plain Response, we saw improvements in median latency ranging from 5% in the single-user scenario up to 60% in the multi-user scenario.
  2. Use microbatching. We had to implement this ourselves, but the idea is that you can generate several images for different prompts in a single pass (provided some other parameters match), making better use of the GPU. Generation is blocking, so batching lets us serve more requests at once; a minimal sketch follows this list.
    • We set two parameters here: the maximum time to wait before fetching the next batch and the maximum number of prompts to fetch. 500 ms and 8 prompts seemed like sensible defaults to start with.
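
The snippet below is a minimal sketch of that idea, not our production implementation: a background worker drains an asyncio queue, waiting at most 500 ms or until 8 prompts have accumulated, then resolves every pending request with the results of a single batched call. The generate_batch function here is a placeholder for the actual Stable Diffusion invocation.

```python
import asyncio

from fastapi import FastAPI, Response
from pydantic import BaseModel

MAX_BATCH_SIZE = 8      # maximum number of prompts per GPU call
MAX_WAIT_SECONDS = 0.5  # maximum time to wait for a batch to fill up

app = FastAPI()


class T2IRequest(BaseModel):
    prompt: str


def generate_batch(prompts: list[str]) -> list[bytes]:
    # Placeholder for the blocking Stable Diffusion call: one encoded image per prompt.
    return [b"png-bytes-for: " + p.encode() for p in prompts]


async def batch_worker(queue: asyncio.Queue) -> None:
    # Collect up to MAX_BATCH_SIZE prompts, or whatever arrived within MAX_WAIT_SECONDS.
    while True:
        batch = [await queue.get()]
        deadline = asyncio.get_running_loop().time() + MAX_WAIT_SECONDS
        while len(batch) < MAX_BATCH_SIZE:
            remaining = deadline - asyncio.get_running_loop().time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        # One blocking GPU call for the whole batch, kept off the event loop.
        images = await asyncio.to_thread(generate_batch, [prompt for prompt, _ in batch])
        for (_, future), image in zip(batch, images):
            future.set_result(image)


@app.on_event("startup")
async def start_worker() -> None:
    app.state.queue = asyncio.Queue()
    asyncio.create_task(batch_worker(app.state.queue))


@app.post("/t2i")
async def t2i(body: T2IRequest) -> Response:
    future = asyncio.get_running_loop().create_future()
    await app.state.queue.put((body.prompt, future))
    return Response(content=await future, media_type="image/png")
```

Grouping by generation parameters, error handling and backpressure are left out for brevity.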

Uvicorn workers

  1. Use uvloop. By adding uvicorn as a dependency with the standard extras, you get uvloop, which increases performance.
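
As a quick illustration of what that looks like (the app.main:app module path is hypothetical), uvicorn can also be pinned to uvloop and httptools explicitly:

```python
import uvicorn

if __name__ == "__main__":
    # With `pip install "uvicorn[standard]"`, uvloop and httptools are picked up
    # automatically ("auto"); here we pin them explicitly to make the choice visible.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, loop="uvloop", http="httptools")
```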

PyTorch optimizations

  1. Enable attention slicing for VRAM savings
  2. Enable memory efficient attention for performance
  3. Enable cuDNN auto-tuner
  4. Disable gradients calculation
  5. Set PyTorch's num_threads to 1
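
As a rough sketch of how these settings translate into PyTorch and diffusers calls (the model id is only an example, and memory-efficient attention is obtained here through xformers):

```python
import torch
from diffusers import StableDiffusionPipeline

# Inference only: skip autograd bookkeeping, let cuDNN benchmark kernels for our
# fixed 512x512 workload, and keep PyTorch to a single CPU thread.
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
torch.set_num_threads(1)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.enable_attention_slicing()                    # trades a bit of speed for VRAM
pipe.enable_xformers_memory_efficient_attention()  # requires the xformers package
```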

Implementing observability

Out of the box, FastAPI has no observability capabilities. But there are lots of packages that add features and can be seamlessly integrated.

Prometheus

Prometheus is an open-source systems monitoring and alerting toolkit originally built at SoundCloud. Prometheus collects and stores its metrics as time series data, i.e. metrics information is stored with the timestamp at which it was recorded, alongside optional key-value pairs called labels.

It is widely used to collect metrics as time series and query them through PromQL.

We will be using cAdvisor to export metrics from our containers into Prometheus, dcgm-exporter to export GPU metrics, and, as we will see later, metrics straight from FastAPI. For each of these, we use the default scrape_interval of 15 seconds.

Grafana

We set up two dashboards for different purposes: monitoring the GPU and monitoring the API container.

The dashboard for monitoring the API container.

The dashboard for monitoring GPU usage.

These are a must-have for monitoring in a production environment. You need to be able to tell what's wrong at a glance.

FastAPI Prometheus instrumentator

We used prometheus-fastapi-instrumentator to add a /metrics endpoint which exports metrics in a prometheus-friendly format.

We had to add latency and requests metrics, but it was a breeze. Read the docs and you will find most use cases explained.
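
A minimal sketch of that wiring, following the package's documented usage (metric options may vary between versions):

```python
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator, metrics

app = FastAPI()

instrumentator = Instrumentator()
instrumentator.add(metrics.latency())   # request duration histogram
instrumentator.add(metrics.requests())  # request counter by handler, method and status
instrumentator.instrument(app).expose(app)  # serves Prometheus-format metrics at /metrics
```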

Stress testing

We evaluated several options to stress test the API under different scenarios and workloads. We needed to easily define scenarios for each endpoint so that we could compare them later.

  • Artillery. Great performance, configurable with code.
  • Apache Bench. Easy setup, wrappable in scripts for different scenarios.
  • Apache JMeter. A classic choice.
  • Locust. For the pythonista. Easy setup. Configurable with code.

Although it's better suited to more complex workflows than ours, we went with Locust: it's easy to set up, scenarios are written in code, we get performance measures out of the box (on top of Prometheus), and we've used it for several projects in the past.

Locust

The setup for each scenario looked similar to this snippet:

```python
import random

from locust import FastHttpUser, task


class T2IHighRes(FastHttpUser):
    @task
    def fetch_image(self):
        self.client.post(
            "/t2i",
            json={
                "seed": random.randint(0, 1_000_000),
                "negative_prompt": "photography",
                "prompt": "digital painting of bronze (metal raven), automaton, all metal, "
                "glowing magical eye, intricate details, perched on workshop bench, "
                "cinematic lighting, by greg rutkowski, in the style of midjourney, ({steampunk})",
                "width": 512,
                "height": 512,
                "inference_steps": 20,
                "guidance_scale": 8.2,
                # "n": 1,
            },
        )
```

(Which produces stunning images, by the way)


Locust reports many metrics that we already had in Grafana. We used both as a double check, but with Grafana/Prometheus you can build more refined views, such as moving averages over 5-minute windows.

Experiments & Results

We conducted all experiments on g4dn.xlarge spot instances on AWS, using the Amazon Linux 2 AMI with Nvidia drivers. Initially, we tried deploying on vast.ai, but we found we could not run Docker containers there due to environment limitations (Docker-in-Docker is not available), so it was not suited to our deployment setup.

g4dn.xlarge instances have 16 GB of RAM, 4 vCPUs and an Nvidia Tesla T4 GPU.

There are five endpoints to generate images; four of them come from combining these traits:

  • Sync vs Async: When using FastAPI, you usually define endpoints as async functions, but many times the underlying service is not async (for instance, SQLAlchemy queries). We tested whether defining the endpoint and the service as async functions made any difference.
  • Stream response vs Response: As stated before, we felt this was necessary to test because of how ubiquitous returning images via StreamingResponse is in online tutorials and reference implementations. We implemented the same endpoint twice, once using a plain Response object and once using StreamingResponse.

And the fifth is the microbatching endpoint, which uses a custom priority queue to serve a configurable number of responses at once. As said before, the batch timeout is set to 500 ms by default and the maximum batch size to 8. This means that every 500 ms, if no group of requests has accumulated 8 requests, we fetch the group containing the oldest request and serve it first. Requests are grouped by width, height, inference_steps, and guidance_scale; a toy illustration of this grouping policy follows.
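
As a sketch of that grouping policy only (not our actual implementation), batch selection could look like this, with pending requests kept in oldest-first order:

```python
from collections import defaultdict

MAX_BATCH_SIZE = 8


def group_key(req: dict) -> tuple:
    # Requests can only share a GPU batch if these generation parameters match.
    return (req["width"], req["height"], req["inference_steps"], req["guidance_scale"])


def pick_batch(pending: list[dict]) -> list[dict]:
    # Prefer any group that already fills a batch; otherwise take the group
    # that contains the oldest pending request.
    groups: dict[tuple, list[dict]] = defaultdict(list)
    for req in pending:  # `pending` is ordered oldest-first
        groups[group_key(req)].append(req)
    for reqs in groups.values():
        if len(reqs) >= MAX_BATCH_SIZE:
            return reqs[:MAX_BATCH_SIZE]
    return groups[group_key(pending[0])][:MAX_BATCH_SIZE]
```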

For all scenarios we got 0 failures, so failures are not reported. The metrics reported here come from Locust, but were cross-checked against the aforementioned dashboards.

Scenario 1

With only 1 user, for 15 minutes.

| Implementation | Requests | Avg (ms) | RPS | p50 (ms) | p90 (ms) | p95 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|---|
| Streaming (Async) | 218 | 4112 | 0.2 | 4100 | 4200 | 4200 | 4200 |
| Response (Async) | 230 | 3912 | 0.3 | 3900 | 4000 | 4000 | 4000 |
| Streaming (Sync) | 226 | 3978 | 0.3 | 4000 | 4000 | 4000 | 4200 |
| Response (Sync) | 236 | 3811 | 0.3 | 3800 | 4000 | 4000 | 4000 |
| Batched (Async) | 368 | 2439 | 0.4 | 2400 | 2400 | 2400 | 2400 |

Even without highlighting the best values, you can see there's a clear winner. The microbatching implementation, albeit a very simple one, greatly improves both latency and throughput. And this is without even tuning the microbatching configuration beyond the "sensible defaults" mentioned before.

Another clear insight for this use case is that using StreamingResponse harms performance, while the async/sync implementation of the endpoints and generation functions makes no difference.

But how does each implementation fare with respect to GPU usage? Let's check some metrics obtained from the Grafana dashboard.

| Implementation | GPU Avg VRAM Use | GPU Avg Utilization |
|---|---|---|
| Streaming (Async) | 80.6% | 68.8% |
| Response (Async) | 60.8% | 59.3% |
| Streaming (Sync) | 59.7% | 71.0% |
| Response (Sync) | 58.5% | 63.1% |
| Batched (Async) | 75.1% | 42.8% |

While one could argue that using more of the GPU is better, measuring that becomes tricky. For instance, the Sync Streaming endpoint uses more of the GPU than the batched one, yet results in lower throughput and higher latency.

Scenario 2

One could argue the scenario above is not quite representative of the workload a production-ready API might encounter. So we built another one: 8 users, with a spawn rate of 0.02 users/second (a new user arrives every 50 seconds), for 15 minutes.

| Implementation | Requests | Avg (ms) | RPS | p50 (ms) | p90 (ms) | p95 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|---|
| Streaming (Async) | 261 | 21951 | 0.3 | 26000 | 28000 | 34000 | 46000 |
| Response (Async) | 269 | 20875 | 0.3 | 9900 | 53000 | 56000 | 69000 |
| Streaming (Sync) | 264 | 21618 | 0.3 | 26000 | 32000 | 34000 | 46000 |
| Response (Sync) | 270 | 20729 | 0.3 | 9800 | 56000 | 59000 | 65000 |
| Batched (Async) | 975 | 5923 | 1.1 | 6400 | 6500 | 6500 | 12000 |

This scenario gives further evidence of our prior conclusions.

| Implementation | GPU Avg VRAM Use | GPU Avg Utilization |
|---|---|---|
| Streaming (Async) | 87.7% | 92.1% |
| Response (Async) | 84.2% | 91.9% |
| Streaming (Sync) | 90.6% | 93.6% |
| Response (Sync) | 60.2% | 87.7% |
| Batched (Async) | 71.3% | 80.3% |

The batched implementation is not bottlenecked by the GPU, which is a good thing: we should be able to keep scaling to more users.

Transformers optimizations

In these tests, we measured the impact of enabling and disabling Memory Efficient Attention (MEA) and Attention Slicing (AS). We used the same ramp-up as in scenario 2: 8 users, with a spawn rate of 0.02 users/second, for 15 minutes.

| MEA | AS | Requests | Avg (ms) | RPS | p50 (ms) | p90 (ms) | p95 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|---|---|
| ✓ | ✓ | 975 | 5923 | 1.1 | 6400 | 6500 | 6500 | 12000 |
|  |  | 932 | 6200 | 1 | 6900 | 7000 | 7000 | 11000 |
|  |  | 957 | 6040 | 1.1 | 6500 | 6700 | 6800 | 13000 |
| ✓ | ✗ | 1131 | 5102 | 1.3 | 5300 | 5800 | 5800 | 8800 |

And as for GPU impact:

| MEA | AS | GPU Avg. VRAM | GPU Avg. Utilization |
|---|---|---|---|
| ✓ | ✓ | 71.3% | 80.3% |
|  |  | 72.7% | 87.4% |
|  |  | 63.9% | 82.5% |
| ✓ | ✗ | 49.9% | 83.8% |

In our tests, enabling memory efficient attention and disabling attention slicing gave the best results, which is why we used that configuration in the next test.

Additionally, we tried using the Channels Last memory format for NCHW tensors, which improved these results even further:

| Requests | Avg (ms) | RPS | p50 (ms) | p90 (ms) | p95 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|
| 1316 | 4383 | 1.5 | 4800 | 5100 | 5300 | 5400 |
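
One common way to apply this with diffusers, sketched here under the same assumptions as before (example model id, half precision on CUDA), is to switch the UNet to the channels-last memory format:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Channels-last (NHWC) layout often lets convolution-heavy modules such as the
# UNet hit faster cuDNN kernels.
pipe.unet.to(memory_format=torch.channels_last)
```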

Batched parameters optimization

By now, we have ample evidence that even the most basic microbatching implementation can greatly improve our target metrics. Can we get better numbers by playing with its parameters? We tried increasing and decreasing both the batch timeout and the maximum batch size. As in the previous test, we set 8 users to arrive at a rate of 0.02 users/second and ran the test for 15 minutes.

| Batch Timeout (ms) | Batch Size | Requests | Avg (ms) | RPS | p50 (ms) | p90 (ms) | p95 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|---|---|
| 500 | 8 | 1131 | 5102 | 1.3 | 5300 | 5800 | 5800 | 8800 |
| 250 | 8 | 1121 | 5145 | 1.2 | 5400 | 6000 | 6000 | 9700 |
| 1000 | 8 | 1030 | 5599 | 1.1 | 5900 | 6000 | 6000 | 12000 |
| 500 | 4 | 1083 | 5331 | 1.2 | 5900 | 5900 | 5900 | 8200 |
| 500 | 16 | 1051 | 5506 | 1.2 | 6000 | 6100 | 6100 | 12000 |

Our takeaway is that those initial "sensible defaults" were... sensible enough. However, if we took these results at face value we'd be overfitting to our test scenario: while changing the batch size makes median latency 5-10% worse here, that could change with a different number of users.

Conclusions and next steps

While the initial results are exciting, they came at a cost, mostly in engineering time spent implementing basic features for serving deep learning models.

We will be taking Triton for a test ride next, and we have high hopes for it. There are several other alternatives to consider, as this is a growing space. After everything you've read in this article, we have a strong baseline to compare against and a clearly defined methodology and performance indicators that will guide us.