
Low Numerical Precision

Efficient deep learning in PyTorch with lower precision datatypes


Most DL models are single-precision floats (FP32) by default.

Moving to lower numerical precision - while reasonably maintaining accuracy - reduces:

model size
memory requirements
power consumption
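
A minimal sketch of the memory saving (the tensor size here is just illustrative): the same tensor stored in FP16 takes half the bytes of its FP32 version.

import torch

# The same 1M-element weight tensor in two dtypes
fp32 = torch.randn(1_000_000, dtype=torch.float32)
fp16 = fp32.to(torch.float16)

print(fp32.element_size() * fp32.nelement())  # 4,000,000 bytes
print(fp16.element_size() * fp16.nelement())  # 2,000,000 bytes
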
⏩ Lower numerical precision is faster

Lower numerical precision speeds up:

compute-bound operations, by reducing load on the hardware
memory bandwidth-bound operations, by accessing smaller data (weights, inputs)

In many deep models, memory access dominates power consumption; reducing memory I/O makes models more energy efficient.
Lower precision datatypes in PyTorch

These datatypes are typically used in PyTorch:

FP16 or half-precision (torch.float16) -- only supported on CUDA
BF16 (torch.bfloat16) -- supported on TPUs and newer CPUs
INT8 (torch.quint8 and torch.qint8), which stores floats in a quantized format
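
As a quick illustration (the scale and zero point below are arbitrary placeholder values), here is how each of these dtypes can be produced from a default FP32 tensor:

import torch

x = torch.randn(4)                    # FP32 by default

half = x.to(torch.float16)            # FP16 / half precision
bf16 = x.to(torch.bfloat16)           # BF16

# INT8: floats stored in a quantized format, value ≈ (int_repr - zero_point) * scale
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.dtype, q.int_repr())
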
⚠️ model.half() ⚠️

Calling .half() on your network and tensors explicitly casts them to FP16.

But not all operations in PyTorch are safe to run in half-precision; some ops require the full dynamic range of FP32 or even FP64.
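
A minimal sketch of explicit FP16 casting (assumes a CUDA device and a toy Linear layer):

import torch

model = torch.nn.Linear(128, 64).cuda().half()   # weights become torch.float16
x = torch.randn(32, 128, device="cuda").half()   # inputs must be cast to match

out = model(x)       # runs entirely in FP16
print(out.dtype)     # torch.float16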

A better solution is to use Automatic Mixed Precision to let PyTorch choose the correct op-specific dtype (FP32 vs FP16/BF16) for your tensors.
✅ Automatic Mixed Precision

For torch <= 1.9.1, AMP was limited to CUDA tensors, using torch.cuda.amp.autocast().

From v1.10 onwards, PyTorch has a generic API, torch.autocast(), that automatically casts:

CUDA tensors to FP16, and
CPU tensors to BF16
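
A sketch of the generic API (toy model and shapes assumed; requires torch >= 1.10):

import torch

model = torch.nn.Linear(128, 64).cuda()
x = torch.randn(32, 128, device="cuda")

# On CUDA, ops inside the block are autocast to FP16 where safe
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)

# The same API on CPU autocasts to BF16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_cpu = model.cpu()(x.cpu())
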
⏩ AMP is usually faster than .half()

(Benchmark: ResNet101 on a Tesla T4 GPU)
⚠️ AMP is only for the forward pass

Don’t wrap the backward pass in autocast().

The backward ops run in the same dtype as the corresponding forward op was autocast to.
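
A sketch of a full training step (toy model, loss, and optimizer assumed): only the forward pass and loss are under autocast; backward and the optimizer step run outside it, with a GradScaler to avoid FP16 gradient underflow.

import torch

model = torch.nn.Linear(128, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(32, 128, device="cuda")
y = torch.randn(32, 1, device="cuda")

# Forward pass and loss under autocast
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)

# Backward and the optimizer step run OUTSIDE the autocast block
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
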
✅ Save network bandwidth in distributed training

You can enable gradient compression to FP16 with DistributedDataParallel (DDP):

https://pytorch.org/docs/stable/ddp_comm_hooks.html#default-communication-hooks
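
A sketch of registering the built-in FP16 compression hook (assumes the process group has already been initialized with torch.distributed.init_process_group and that the model fits on one GPU):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

model = torch.nn.Linear(128, 64).cuda()
ddp_model = DDP(model)

# Gradients are compressed to FP16 before being all-reduced across workers
ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
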
❓ What about non-BF16 and ARM CPUs?

✅ Lower precision is currently enabled via Quantization.

Quantization converts FP32 to INT8, with a potential 4x reduction in model sizes.

Currently only the forward pass is quantizable, so you can use this only for inference, not training.
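
A sketch of post-training dynamic quantization on a toy model (layer choices and shapes are placeholders):

import torch

model_fp32 = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# Quantize the Linear layers' weights to INT8; activations are quantized dynamically
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8
)

out = model_int8(torch.randn(1, 128))   # inference only; training is not supported
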
Learn more

Half Precision:
https://pytorch-dev-podcast.simplecast.com/episodes/half-precision

torch.autocast:
https://pytorch.org/docs/1.10./amp.html#id4

AMP Examples:
https://pytorch.org/docs/stable/notes/amp_examples.html

Quantization in PyTorch:
https://pytorch.org/docs/stable/quantization.html
