This article gives you the behind-the-scenes of how we made an efficient inference server that powers https://huggingface.co/bigscience/bloom.
We achieved a 5x latency reduction over several weeks (and 50x more throughput). We wanted to share all the struggles and epic wins we went through to achieve such speed improvements.
A lot of different people were involved at many stages, so not everything will be covered here. And please bear with us: some of the content may be outdated or flat out wrong, because
we're still learning how best to optimize extremely large models, and lots of new
hardware features and content keep coming out regularly.
If your favorite flavor of optimizations
isn't discussed or is improperly represented, we're sorry, please share it with us:
we're more than happy to try out new stuff and correct our mistakes.
Creating BLOOM
This goes without saying, but without the large model being accessible in the first
place, there would be no real reason to optimize inference for it. This was an
incredible effort led by many different people.
To maximize GPU utilization during training, several solutions were explored,
and in the end, Megatron-Deepspeed was chosen to train the final model.
This meant that the code, as-is, wasn't necessarily compatible with the transformers
library.
Porting to transformers
Because of the original training code, we set out to do something we regularly
do: port an existing model to transformers. The goal was to extract the relevant
parts from the training code and implement them within transformers.
This effort was tackled by Younes.
This is by no means a small effort, as it took almost a month and 200 commits to get there.
There are several things to note that will come back later:
We needed to have smaller models: bigscience/bigscience-small-testing and bigscience/bloom-560m.
This is extremely important because they are smaller, so everything is faster when
working with them.
First, you have to abandon all hope of having exactly the same logits at the end, down
to the bytes. PyTorch versions can change the kernels and introduce subtle differences, and different hardware
might yield different results because of different architectures (and you probably
don't want to develop on an A100 GPU all the time for cost reasons).
Getting a strict test suite is really important for all models.
The best test we found was having a fixed set of prompts. You know the prompt,
you know the completion that needs to be deterministic, so greedy.
If two generations are identical, you can basically ignore small logits differences.
Whenever you see a drift, you need to investigate. It could be that your code
isn't doing what it should, OR that you are actually out of domain for that model
and therefore the model is more sensitive to noise. If you have several prompts
and long enough prompts, you're less likely to trigger that for all prompts by
accident. The more prompts the better, and the longer the better.
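To make the idea concrete, here is a minimal sketch of such a greedy-generation regression test; the prompt and expected completion are placeholders, not the actual fixtures we used:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_greedy_generation_is_stable():
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype=torch.float16
    ).to("cuda")

    # Placeholder prompt/completion pairs: in practice these are recorded once and frozen.
    prompts_and_expected = {
        "The capital of France is": " Paris. The capital of Germany is",
    }
    for prompt, expected in prompts_and_expected.items():
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        # Greedy decoding is deterministic, so small logits drifts usually
        # do not change the generated text.
        output = model.generate(**inputs, do_sample=False, max_new_tokens=10)
        completion = tokenizer.decode(output[0, inputs["input_ids"].shape[1]:])
        assert completion == expected
```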
The first model (small-testing) is in bfloat16 like the big bloom, so
everything should be very similar, but it wasn't trained a lot or simply doesn't perform
well, so it fluctuates highly in outputs. That means we had issues with those generation
tests. The second model is more stable, but was trained and saved in float16 instead
of bfloat16, so there is more room for error between the two.
To be perfectly fair, the bfloat16 -> float16 conversion seemed to be OK in inference
mode (bfloat16 mostly exists to handle large gradients, which don't exist in inference).
During that step, one important tradeoff was discovered and implemented.
Because bloom was trained in a distributed setting, part of the code was doing
Tensor Parallelism on a Linear layer, meaning that running the same operation as a single
operation on a single GPU gave different results.
This took a while to pinpoint, and we faced a choice: either we went for 100% compliance and the model
was much slower, or we accepted a small difference in generation
for code that was much faster to run and simpler. We opted for a configurable flag.
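To see why a sharded Linear cannot match the single-GPU operation bit for bit, here is a tiny illustration of the floating-point effect at play; there is no actual tensor parallelism here, just the same matmul computed in one go versus as two partial sums, which is what an all-reduce over shards ends up doing:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 1024)
w = torch.randn(1024, 1024)

# Single "GPU": one matmul over the full reduction dimension.
full = x @ w

# "Tensor parallel" version: split the reduction dimension into two shards and
# sum the partial results, as an all-reduce would. Same math, but a different
# floating-point summation order.
sharded = x[:, :512] @ w[:512] + x[:, 512:] @ w[512:]

# Typically small but non-zero in float32, and noticeably larger in bf16/fp16.
print((full - sharded).abs().max())
```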
First inference (PP + Accelerate)
Note: Pipeline Parallelism (PP) means in this context that each GPU will own
some layers, so each GPU will work on a given chunk of data before handing
it off to the next GPU.
Now that we have a workable, clean transformers version of the model, we can start
working on running it.
Bloom is a 352GB (176B parameters in bf16) model; we need at least that much
GPU RAM to make it fit. We briefly explored offloading to CPU on smaller machines,
but the inference speed was orders of magnitude slower, so we discarded it.
Then we wanted to basically use the pipeline.
So it's dogfooding, and this is what the API uses under the hood all the time.
However, pipelines aren't distributed-aware (that is not their goal). After briefly
discussing options, we ended up using accelerate's newly
created device_map="auto" to manage the sharding of the model. We had to iron
out a few bugs, and fix the transformers code a bit to help accelerate do the right job.
It works by splitting the various layers of the transformers and giving part of
the model to each GPU. So GPU0 gets to work, then hands it over to GPU1, and so
on.
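For reference, a minimal sketch of what this first version roughly looks like (not the exact server code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# accelerate decides which GPU owns which layers and shards the model,
# pipeline style, across all visible GPUs.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```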
In the end, with a small HTTP server on top, we could start serving bloom (the big model)!!
Starting point
But we haven't even started discussing optimizations yet!
We actually have quite a lot to do: this whole process is a castle of cards. During
optimizations we are going to make modifications to the underlying code, and being
extra sure you aren't killing the model in some way is really important,
and easier to get wrong than you think.
So we are now at the very first step of optimizations, and we need to start measuring
and keep measuring performance. So we need to consider what we care about.
For an open inference server supporting many options, we expect users to send
many queries with different parameters, and what we care about is:
The number of users we can serve at the same time (throughput)
How long it takes for an average user to be served (latency)
We made a testing script in locust which is precisely this:
```python
from locust import HttpUser, between, task
from random import randrange, random


class QuickstartUser(HttpUser):
    wait_time = between(1, 5)

    @task
    def bloom_small(self):
        # Greedy-ish request: only max_new_tokens and a seed.
        sentence = "Translate to chinese. EN: I like soup. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {"max_new_tokens": 20, "seed": random()},
            },
        )

    @task
    def bloom_small_sampling(self):
        # Same prompt, but with sampling parameters.
        sentence = "Translate to chinese. EN: I like soup. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {
                    "max_new_tokens": 20,
                    "do_sample": True,
                    "top_p": 0.9,
                    "seed": random(),
                },
            },
        )
```
**Note: This isn't the best nor the only load testing we used, but it was
always the first to be run so that it could compare fairly across approaches.
Being the best on this benchmark does NOT mean it is the best solution. Other,
more complex scenarios had to be used in addition to actual real-world performance.**
We wanted to observe the ramp-up for various implementations and also make sure
that under load the server properly circuit breaks. Circuit breaking means
that the server can answer (fast) that it will not answer your query because too
many people are trying to use it at the same time.
It's extremely important to avoid the hug of death.
On this benchmark the initial performance was (on 16xA100 40GB on GCP, which is the machine used throughout):
Requests/s: 0.3 (throughput)
Latency: 350ms/token (latency)
Those numbers aren't that great. Before getting to work, let's estimate
the best we can imagine achieving.
The formula for the number of operations is 24Bsh^2 + 4Bs^2h, where B is
the batch size, s the sequence length, and h the hidden dimension.
Let's do the math and we get 17 TFlop for a single forward pass.
Looking at the specs of the A100, it claims 312 TFLOPS for a single card.
That means a single GPU could potentially run at 17 / 312 = 54ms/token. We're using 16 of them, so 3ms/token on the overall
machine. Take all these numbers with a big grain of salt: it's never possible to reach those numbers,
and real-life performance rarely matches the specs. Also, if computation isn't your limiting
factor, then this isn't the lowest you can get. It's just good practice to know how far you are from
your target. In this case, we're 2 orders of magnitude away, so pretty far. Also, this estimate puts
all the flops at the service of latency, which means only a single request can go at a time (that's okay because you're maximizing your machine
so there's not much else to be done; but we can accept higher latency and get throughput back through batching much more easily).
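For reference, here is that back-of-the-envelope estimate in code. We read the formula as per-layer, so it gets multiplied by BLOOM's 70 layers with hidden size 14336; the prompt length used here is only illustrative:

```python
def forward_flops(batch_size: int, seq_len: int, hidden: int, n_layers: int) -> float:
    """Back-of-the-envelope forward-pass FLOPs, reading the formula above as per-layer."""
    per_layer = 24 * batch_size * seq_len * hidden**2 + 4 * batch_size * seq_len**2 * hidden
    return per_layer * n_layers

# BLOOM-176B has 70 layers and a hidden size of 14336; the prompt length is illustrative.
flops = forward_flops(batch_size=1, seq_len=50, hidden=14336, n_layers=70)
print(f"{flops / 1e12:.1f} TFlop per forward pass")                  # ~17 TFlop
print(f"{flops / 312e12 * 1000:.0f} ms at A100 peak (312 TFLOPS)")   # ~55 ms on one card
```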
Exploring many routes
Note: Tensor Parallelism (TP) means in this context that each GPU will own
part of the weights, so ALL GPUs are active all the time and do less work.
Usually this comes with a very slight overhead since some work is duplicated,
and more importantly, the GPUs regularly have to communicate their results
to each other to continue the computation.
Now that we have a good understanding of where we stand, it's time to get to work.
We tried many different things based on the people involved and our various knowledge.
ALL endeavors deserve their own blog post, so I'll just list them, explain the
few final learnings, and delve into the details of only what went into the current
server. Moving from Pipeline Parallelism (PP) to Tensor Parallelism (TP) is
one big interesting change for latency. Each GPU will own part of the parameters
and all will be working at the same time. So the latency should decrease drastically,
but the price to pay is the communication overhead, since they regularly need
to communicate their results to each other.
It is to be noted that this is a very wide range of approaches, and the intent
was deliberately to learn more about each tool and how it could fit into later
endeavors.
Porting the code to JAX/Flax to run on TPUs:
- Expected to be easier to choose the type of parallelism, so TP should be easier to test. It's one of the perks of Jax's design.
- More constrained on hardware: performance on TPU is likely superior to GPU, but there is less vendor choice for TPU.
- Cons: another port is needed. But it would be welcome anyway in our libs.
Results:
- Porting was not an easy task, as some conditions and kernels were hard to reproduce correctly enough. Still manageable though.
- Parallelism was quite easy to get once ported. Kudos to Jax, the claim holds.
- Ray/communicating with TPU workers proved to be a real pain for us.
We don't know if it was the tool, the network, or simply our lack of knowledge,
but it slowed down experiments and work much more than we anticipated.
We would launch an experiment that takes 5mn to run, wait for 5mn, nothing
had happened, 10mn later still nothing; it turned out some worker was down/not responding,
and we had to manually get in, figure out what went on, fix it, restart something, and relaunch, and we had just lost half an hour.
Repeat that enough times, and lost days add up quickly.
Let's emphasize that this is not necessarily a critique of the tools we used,
but the subjective experience we had remains.
- No control over compilation.
Once we had the thing running, we tried several settings to figure out which
best suited the inference we had in mind, and it turned out it was really hard
to guess from the settings what would happen to the latency/throughput. For instance,
we had 0.3 rps at batch_size=1 (so every request/user is on its own) with a latency of
15ms/token (don't compare this too much with the other numbers in this article, it's on a different machine with
a very different profile), which is great, but the overall throughput isn't much better than
what we had with the old code. So we decided to add batching, and with BS=2 the
latency went up 5 fold, with only 2 times the throughput... Upon further investigation,
it turned out that up to batch_size=16 every batch_size had the same latency profile.
So we could have 16x more throughput at a 5x latency cost. Not bad, but looking
at the numbers we really would have preferred a more fine-grained control.
The numbers we were aiming for stem from the 100ms, 1s, 10s, 1mn rule.
Using ONNX/TRT or other compiled approaches
- They are supposed to handle most of the optimization work.
- Con: parallelism usually has to be handled manually.
Results:
- It turned out that in order to trace/jit/export stuff, we needed to
rework part of the PyTorch code anyway, so it easily fused with the pure PyTorch approach.
And overall we figured out that we could get most of the optimizations we wanted
by staying within the PyTorch world, enabling us to keep flexibility without
having to make too much coding effort.
- Another thing to note: since we're running on GPU and text generation has many
forward passes going on, we need the tensors to stay on the GPU, and it is
sometimes hard to send your tensors to some lib, be given back the result, perform
the logits computation (like argmax or sampling), and feed it back again.
Putting the loop within the external lib means losing flexibility, just like with
Jax, so it was not envisioned in our use case.
DeepSpeed
- This is the technology that powered training; it seemed only fair to use
it for inference.
- Cons: it was never used/prepared for inference before.
Results:
- We had really impressive results fast, roughly the same as
the latest iteration we are currently running.
- We had to invent a way to put a webserver (so dealing with concurrency) on
top of DeepSpeed, which also has several processes (one for each GPU). Since then,
there is an excellent library, Mii.
It doesn't fit the extremely flexible goals we had in mind, but we probably
would have started working on top of it now. (The current solution is discussed later.)
- The biggest caveat we encountered with DeepSpeed was the lack of stability.
We had issues when running it on CUDA 11.4 when the code was built for 11.6.
And the long-standing issue we could never really fix is that there would
be regular kernel crashes (CUDA illegal access, dimension mismatches, etc.).
We fixed a bunch of these, but we could never quite achieve stability under the stress
of our webserver. Despite that, I want to shout out to the Microsoft folks that
helped us; we had a really good conversation that improved our understanding
of what was happening and gave us real insights for some follow-up work.
- One of the pain points, I feel, is that our team is mostly in Europe, while
Microsoft is in California, so the collaboration was tricky timewise and we
probably lost a big chunk of time because of it. This has nothing to do
with the technical part, but it's good to acknowledge that the organizational
part of working together is also really important.
- Another thing to note is that DeepSpeed relies on
transformers to inject
its optimization, and since we were updating our code pretty much constantly,
it made it hard for the DeepSpeed team to keep things working on our main
branch. We're sorry to have made it hard; I guess this is why it's called
bleeding edge.
Webserver ideas
- Given that we are going to run a free server where users are going to
send long text, short text, want a few tokens, or a whole recipe, each with
different parameters, something had to be done here.
Results:
- We recoded everything in
Rust with the excellent bindings tch-rs. Rust was not aimed at getting performance gains, but just
much more fine-grained control over parallelism (threads/processes) and playing
at a more fine-grained level with the webserver concurrency and the PyTorch one.
Python is infamously hard for handling low-level details because of the GIL.
- It turned out that most of the pain came from the port, and after that, the experimentation
was a breeze. And we figured that with enough control over the loops
we could have great performance for everyone, even in the context of a very
wide range of requests with different properties. Code for the curious, but it doesn't come with any support or nice docs.
- It became production for a few weeks because it was more lenient on the parallelism; we could use the GPUs more efficiently (using GPU0 for request 1
while GPU1 is treating request 0),
and we
went from 0.3 RPS to ~2.5 RPS with the same latency. The optimal case would have been to increase throughput by 16X, but the numbers shown here
are real workload measurements, so this isn't too bad.
Pure PyTorch
- Purely modify the existing code to make it faster by removing operations
like `reshape`, using better-optimized kernels, and so on.
- Con: we have to code TP ourselves, and we have the constraint that the code still fits our library (mostly).
Results
Final route: PyTorch + TP + 1 custom kernel + torch.jit.script
Writing more efficient PyTorch
The first item on the list was removing unnecessary operations in the first implementations.
Some can be seen by just looking at the code and figuring out obvious flaws:
- Alibi is used in Bloom to add position embeddings, and it was calculated in too
many places; we could calculate it only once and more efficiently.
The old code: link
The new code: link
This is a 10x speedup, and the latest version includes padding too!
Since this step is only computed once, the actual speed isn't important,
but overall, reducing the number of operations and tensor creations is a good direction.
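For context, a simplified sketch of what "computing alibi once" amounts to; this is not the exact transformers implementation, and it assumes the number of heads is a power of two:

```python
import torch

def build_alibi_sketch(attention_mask: torch.Tensor, n_heads: int, dtype=torch.bfloat16) -> torch.Tensor:
    # Simplified sketch: one fixed slope per head, multiplied by the token position,
    # computed once per batch instead of being recomputed in many places.
    slopes = torch.pow(2.0, -8.0 / n_heads * torch.arange(1, n_heads + 1, dtype=torch.float32))
    # Position indices that respect left padding: cumulative sum of the mask, minus one.
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask
    alibi = slopes[None, :, None] * positions[:, None, :]
    return alibi.to(dtype)  # shape: (batch_size, n_heads, seq_len)
```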
Other parts come out more clearly when you start profiling, and we used the tensorboard extension quite extensively.
This provides this kind of image, which gives insights:

Attention takes a lot of time; careful, this is a CPU view, so the long
bars don't mean long, they mean the CPU is awaiting the GPU results of the
previous step.

We see many `cat` operations before `baddbmm`.
Removing a lot of reshape/transpose, for instance, we figured out that:
- The attention is the hot path (this was expected, but it is always good to verify).
- In the attention, a lot of kernels were actual copies due to the massive amount of reshapes.
- We could remove the reshapes by reworking the weights themselves and the past.
This is a breaking change, but it did improve performance quite a bit!
Supporting TP
Okay, we have now removed most of the low-hanging fruit: we went roughly from 350ms/token
latency to 300ms/token in PP. That's a 15% reduction in latency; it actually provided
more than that, but we weren't extremely rigorous in our measuring initially, so let's stick to that figure.
Then we went on to provide a TP implementation. It turned out to be much faster
than we anticipated: the implementation took half a day for a single (experienced) dev.
The result is here. We were also able to reuse code from other projects, which helped.
The latency went directly from 300ms/token to 91ms/token, which is a huge improvement in user experience.
A simple 20-token request went from 6s to 2s, moving from a "slow" experience to a slightly delayed one.
Also, the throughput went up a lot, to 10 RPS. The throughput comes from the fact
that running a query at batch_size=1 takes the same time as batch_size=32,
so throughput becomes essentially free in latency cost at this point.
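For readers unfamiliar with TP, here is a minimal sketch of a row-parallel Linear layer, the core building block; the all-reduce is the communication overhead mentioned earlier. This is illustrative only, not the actual BLOOM implementation:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

class RowParallelLinear(nn.Module):
    """Minimal sketch: each rank holds a slice of the weight's input dimension.
    The full output is recovered with a single all-reduce over the partial results."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        world_size = dist.get_world_size()
        assert in_features % world_size == 0
        self.weight = nn.Parameter(torch.empty(out_features, in_features // world_size))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        # x_shard: the slice of the activation matching this rank's weight slice.
        partial = F.linear(x_shard, self.weight)
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # sum partial results across GPUs
        return partial + self.bias
```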
Low-hanging fruits
Now that we had a TP implementation, we could start profiling and optimizing again.
It's a significant enough shift that we had to start from scratch.
The first thing that stood out is that synchronization (ncclAllReduce) starts
to become a preponderant part of the load, which is expected: this is the synchronization
part and it does take some time. We never tried to look into and optimize this, as it's
already using nccl, but there might still be some room for improvement there.
We assumed it would be hard to do much better.
The second thing is that the Gelu operator was launching many elementwise
kernels and overall was taking a bigger share of compute than we expected.
We made the change from:
```python
def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
```
to
```python
@torch.jit.script
def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
```
This transforms the operations from multiple small elementwise kernels (and hence tensor copies)
to a single kernel operation!
This provided a 10% latency improvement, from 91ms/token to 81ms/token, right there!
Be careful though, this isn't some magic black box you can just throw everywhere:
the kernel fusion will not necessarily happen, or the previously used operations
may already be extremely efficient.
Places where we found it worked well:
- You have a lot of small/elementwise operations
- You have a hotspot with a few hard-to-remove reshapes, or copies in general
- When the fusion actually happens.
Epic fail
We also had some points, during our testing periods, where we ended up seeing a consistent
25% lower latency for the Rust server compared to the Python one. This was fairly
odd, but because it was consistently measured, and because removing kernels had provided a speedup,
we were under the impression that maybe dropping the Python overhead could
provide a nice boost.
We started a 3-day job to reimplement the necessary parts of torch.distributed
to get up and running in the Rust world, nccl-rs.
We had the version working, but something was off in the generations compared to its
Python counterpart. During the investigation of the issues, we figured out...
that we had forgotten to remove the profiler in the PyTorch measurements...
That was the epic fail, because removing it gave us back the 25%, and then both
codebases ran just as fast. This is what we initially expected, that Python should not
be a performance hit, since it mostly runs torch's C++ code. In the end,
3 days is not the end of the world, and it might become useful sometime in the
future, but it was still pretty bad.
It is quite common when doing optimizations to make wrong or misrepresentative
measurements, which end up being disappointing or even detrimental to the overall
product. This is why doing it in small steps and having expectations about the
outcome as soon as possible helps contain that risk.
Another place where we had to be extra careful was the initial forward pass (without
past) versus the later forward passes (with past). If you optimize the first one,
you are most certainly going to slow down the later ones, which are much more
important and account for most of the runtime.
Another pretty common culprit is measuring times which are CPU times, and not
actual CUDA times, so you need to call torch.cuda.synchronize() when doing timing
runs to be sure that the kernels have completed.
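A small sketch of how we would time a GPU operation properly; the helper name is ours:

```python
import torch

def time_cuda(fn, warmup: int = 3, iters: int = 10) -> float:
    """Return the average wall time of `fn()` in milliseconds, measured on the GPU.
    Without the synchronization, you only measure the (asynchronous) kernel launches."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # make sure all kernels have completed before reading the timer
    return start.elapsed_time(end) / iters
```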
Custom kernel
So far, we had achieved close to DeepSpeed performance without any custom code
outside of PyTorch! Pretty neat. We also didn't have to make any compromise
on the flexibility of the runtime batch size!
But given the DeepSpeed experience, we wanted to try writing a custom kernel
to fuse a few operations in the hot path where torch.jit.script wasn't able to
do it for us. Essentially the following two lines:
```python
attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
```
The first masked fill is creating a new tensor, which is here only to
tell the softmax operator to ignore those values. Also, the softmax needs to be calculated
in float32 (for stability), but within a custom kernel we could limit the amount of
upcasting necessary, so we limit it to the actual sums and accumulators needed.
Code can be found here.
Keep in mind we had a single GPU architecture to target, so we could focus on it,
and we are not experts (yet) at writing kernels, so there could be better ways
to do this.
This custom kernel provided yet another 10% latency improvement, moving down from
81ms/token to 71ms/token, all while keeping our flexibility.
After that, we investigated and explored other things like fusing more operators,
removing other reshapes, or putting them elsewhere. But no attempt ever made
a significant enough impact to make it into the final versions.
Webserver part
Just like the Rust counterpart, we had to implement the batching of requests
with different parameters. Since we are in the PyTorch world, we have pretty
much full control of what is going on.
Since we are in Python, we have the limiting factor that torch.distributed
needs to run on several processes instead of threads, which means it is slightly
harder to communicate between processes. In the end, we opted to communicate
raw strings over a Redis pub/sub to distribute the requests to all processes at once.
Since we are in different processes, it is easier to do it that way than communicating
tensors (which are way bigger), for example.
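A minimal sketch of that distribution scheme; the channel name and payload format are illustrative, not the actual server code:

```python
import json
import redis

# The webserver publishes the raw request; every torch.distributed process
# subscribes and runs the same generation step, so all ranks stay in sync.

def publish_request(client: redis.Redis, inputs: str, parameters: dict) -> None:
    client.publish("requests", json.dumps({"inputs": inputs, "parameters": parameters}))

def worker_loop() -> None:
    client = redis.Redis()
    pubsub = client.pubsub()
    pubsub.subscribe("requests")
    for message in pubsub.listen():
        if message["type"] != "message":
            continue
        request = json.loads(message["data"])
        # In the real server, each rank feeds the request into its shard of the model here.
        print(request)
```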
Then we had to drop the use of generate, since
it applies the same parameters to all members of the batch, and we actually
want to apply a different set of parameters to each.
Thankfully, we can reuse lower-level items like the LogitsProcessor
to save us a lot of work.
So we reconstructed a generate function that takes a list of parameters
and applies them to each member of the batch.
Another really important aspect of the final UX is latency.
Since we have different parameter sets for different requests, we might have
one request for 20 tokens and another for 250 tokens. Since latency is
75ms/token, one request takes 1.5s and the other 18s. If we were
batching all the way through, we would be making the user that asked for 20 tokens wait for 18s,
making it appear to them as if we were running at 900ms/token, which is quite slow!
Since we are in a PyTorch world with extreme flexibility, what we can do instead
is extract that first request from the batch as soon as we have generated its first 20
tokens, and return it to that user within the requested 1.5s! We also happen to save 230 tokens' worth of computation.
So flexibility is important to get the best possible latency out there.
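Here is a minimal sketch of that custom generation loop. The names are ours and, unlike the real server, it does not reuse past_key_values, to keep the example short; each row of the batch carries its own logits processors and its own token budget, and finished rows are handed back immediately:

```python
import torch
from transformers import LogitsProcessorList, TopPLogitsWarper

@torch.no_grad()
def generate_batch(model, input_ids: torch.Tensor, requests: list[dict]) -> dict[int, torch.Tensor]:
    # requests: one dict per row, e.g. {"max_new_tokens": 20, "top_p": 0.9}.
    processors = [
        LogitsProcessorList([TopPLogitsWarper(req.get("top_p", 1.0))]) for req in requests
    ]
    budgets = [req["max_new_tokens"] for req in requests]
    active = {i: input_ids[i : i + 1] for i in range(len(requests))}
    finished = {}

    while active:
        batch = torch.cat(list(active.values()), dim=0)
        logits = model(input_ids=batch).logits[:, -1, :]
        for row, idx in enumerate(list(active.keys())):
            # Each request gets its own set of logits processors.
            scores = processors[idx](active[idx], logits[row : row + 1])
            next_token = torch.multinomial(scores.softmax(dim=-1), num_samples=1)
            active[idx] = torch.cat([active[idx], next_token], dim=-1)
            budgets[idx] -= 1
            if budgets[idx] == 0:
                # This request is done: hand it back now and drop it from the batch.
                finished[idx] = active.pop(idx)
    return finished
```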
Last notes and crazy ideas
Optimization is a never-ending job, and like any other project, 20% of the work
will usually yield 80% of the results.
At some point, we started having a small testing strategy to figure out the
potential yield of an idea, and if the tests didn't show significant
results, then we discarded it. 1 day for a 10% increase is valuable enough; 2 weeks for 10X
is valuable enough. 2 weeks for 10% is not so interesting.
Have you tried ...?
Stuff we know exists and haven't used because of various reasons. It
could be that it felt like it wasn't adapted to our use case, it was too much
work, the yields weren't promising enough, or even simply that we had too many
options to try out and discarded some for no particular reason, just for
lack of time. The following are in no particular order:
Please feel free to reach out if your favorite tool is missing from
here, or if you think we missed out on something important that could
prove useful!
Flash attention
We have briefly looked at integrating flash attention, and while it performs extremely
well on the first forward pass (without past_key_values), it didn't yield as big improvements
when running with past_key_values. Since we would have needed to adapt it to include the alibi tensor
in the calculation, we decided not to do the work (at least not yet).
OpenAI Triton
Triton is a great framework for building custom kernels
in Python. We want to get to use it more, but we haven't so far. We would
be eager to see if it performs better than our CUDA kernel. Writing directly in
CUDA seemed like the shortest path for our goal when we considered our options
for that part.
Padding and Reshapes
As mentioned throughout this article, every tensor copy has a cost, and another
hidden cost of running production is padding. When two queries come in with very
different lengths, you have to pad (use a dummy token) to make them fit a square.
This leads to possibly a lot of unnecessary calculations. More information.
Ideally, we would be able to not do those calculations at all, and never have reshapes.
TensorFlow has the concept of RaggedTensor and
PyTorch has Nested tensors. Both of these
seem less streamlined than regular tensors but might enable us to do less computation,
which is always a win.
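A tiny illustration of what the padding can cost, reusing the 20-token / 250-token example from earlier (pure arithmetic, nothing model-specific):

```python
# Batching a 20-token and a 250-token query means padding the short one to 250,
# and every padded position still goes through the full forward pass.
lengths = [20, 250]
computed = max(lengths) * len(lengths)   # positions actually computed after padding
useful = sum(lengths)                    # positions we actually needed
print(f"wasted compute: {1 - useful / computed:.0%}")  # ~46%
```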
In an ideal world, the entire inference would be written in CUDA or as a pure GPU implementation.
Considering the performance improvements we got when we could fuse operations, this looks desirable.
But to what extent this would deliver, we have no idea. If smarter GPU people have
ideas, we are listening!
Acknowledgments
All this work is the result of the collaboration of many HF team members. In no particular
order, @ThomasWang @stas
@Nouamane @Suraj
@Sanchit @Patrick
@Younes @Sylvain
@Jeff (Microsoft) @Reza
And all of the BigScience organization.
