Neural Network Architecture
The architecture of a neural network is a complete mathematical description of the series of transformations that data undergoes as it passes through the network. The architecture determines what a model is able to learn or represent, and it therefore has a direct impact on the resulting accuracy of a model trained on any specific dataset. Unfortunately, it is (currently) impossible to describe what any particular architecture can learn or to precisely determine the representational needs of most datasets (especially image- or text-based datasets). Consequently, the process of finding an appropriate architecture ultimately requires experimentation.
The literature is full of strategies for finding or designing a suitable neural network architecture for a given dataset—broadly called Neural Architecture Search (NAS). These strategies vary in complexity, in how easily they parallelize, in how hardware- and time-intensive they are, and in how well they work. In this section, we aim to provide a strategy that anyone can follow and, with some luck, use to find a sufficiently good architecture that can be trained into a suitable model.
A Simple Architecture Search Strategy
The following is a practitioner's guide to finding a suitable architecture. Before beginning, you should have a good idea of what suitable means for your application. Generally, suitable means satisfying some constraints: What is the minimum accuracy required to be useful? How fast does the model need to be on my hardware? Each experiment (training run) you do with an architecture provides data points that inform your future experiments. Conveniently, Chariot tracks all your training runs, so you can return to completed runs at any time to review or compare them.
Step 1: Just pick an architecture to start.
Pick a starting architecture, and try training the model.
Chariot provides good defaults for a first choice of models.
For classification tasks, the default choice is ResNet18, which is the smallest model in the ResNet family.
It is generally recommended to start small: training will be faster, and you will get your first data point(s) sooner.
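If you want to see what this starting point looks like outside Chariot, the rough torchvision equivalent is below (a minimal sketch, assuming a recent torchvision version; the number of classes is a placeholder for your own dataset):

```python
import torch.nn as nn
from torchvision import models

# Start small: ResNet18 with pre-trained weights, then replace the classification
# head to match the number of classes in your dataset (10 here is a placeholder).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)
```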
Step 2: Evaluate the results.
At each step, it is important to evaluate the suitability of the model and use that information to guide further iterations.
Training run metrics are computed and stored periodically during the run—users can specify the frequency at which these metrics are produced.
Metrics are available live during the run (for monitoring) on the Training Run > Metrics page.
After a successful training run and evaluation of the model, you will most likely be in one of the following scenarios (determine which, and follow the recommendation).
Scenario 1: Suitable
In this scenario, the model meets or exceeds all your suitability requirements.
You can stop here and just take the model you have now trained.
Or, you could try to find a better model.
If better means faster, you could experiment with architectures whose relative inference time is lower than this model's—generally, either smaller models in the same family or models from another family.
If better means more accurate, you could try training the model longer, adding data (and training longer), or training a larger model.
If you choose to train a larger model, try to stay in the same family and incrementally increase the size/depth of the model, e.g., substitute a ResNet34 model for a ResNet18 model (instead of jumping from ResNet18 all the way to a ResNet152 model).
Then, train and reevaluate.
Scenario 2: Underfit
In this case, your model failed to learn at all, or it simply didn't learn enough (the accuracy is too low). There are several options to consider, any of which may produce a better model. Retry the training with the same architecture but different optimizer settings.
You might adjust the learning rate. If your batch size was smaller than the default (32), you probably want to decrease the default learning rate (from 0.001 to maybe 0.0005). Likewise, if the batch size is very large (say, for example, larger than 256 or 512), then you may need to increase the learning rate from the default. If the training metrics look very noisy (frequent large variation in the training loss), this can indicate that your model is being adjusted too rapidly. In that case, reducing the learning rate is also likely to be helpful. If the metrics are nearly constant, then that could indicate that your model is being adjusted too slowly. In that case, increasing the learning rate may be helpful.
If adjustments to the learning rate don't yield better results, then consider other optimizer settings as well. You might reduce the weight decay (if you had any), or try a different optimizer altogether (Adam is a good default, but in some situations others may work better). You might even retry the training run (briefly) with identical settings—especially if this training run was "from scratch" and not starting from any pre-trained weights. If you've tried some of the above with no success, then try a larger model (either more parameters or more layers). Best practice is to incrementally increase the size of the model within the same family before experimenting with models from other families. If your model is also too slow for your suitability requirements, then choosing a larger model from the same family is not advised (as it will be even slower).
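As a minimal sketch of how these settings fit together in plain PyTorch (the linear learning-rate scaling rule is a common heuristic, not a Chariot setting, and the batch size here is a placeholder):

```python
import torch
from torchvision import models

model = models.resnet18()        # stand-in for whatever architecture you are retrying

batch_size = 64                  # placeholder; use your actual batch size
base_lr = 0.001                  # the default learning rate at the reference batch size of 32
lr = base_lr * batch_size / 32   # one common linear-scaling heuristic for other batch sizes

# Adam is a good default; keep weight decay at zero while the model is underfitting,
# and revisit the learning rate before reaching for a larger architecture.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
```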
Scenario 3: Overfit
In this case, your model performs well (very small loss) on the training data but poorly on the test/validation data. If at some point during the training process the model was good on both the training and test/validation data, then you could simply use the model from an earlier checkpoint (less training, before it overfit) if that accuracy is suitable. In that case, the earlier checkpoint puts you into Scenario 1. If the validation accuracy was never good, then your model overfit to the training data too quickly. The two best options to consider here are either using a smaller architecture (fewer parameters or fewer layers)—again, first from the same family but alternatively from another family—or increasing the weight decay in the training run settings. If the model is also too slow, then opt first for a smaller architecture over increasing the weight decay, as the latter will have no noticeable impact on speed.
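If you track validation accuracy per checkpoint, recovering the pre-overfitting model is just a matter of selecting the checkpoint with the best validation score; a tiny sketch with hypothetical numbers:

```python
# Hypothetical (epoch, validation accuracy, checkpoint path) records; in practice these
# come from the metrics stored with your training run.
history = [
    (4, 0.78, "epoch_4.ckpt"),
    (8, 0.84, "epoch_8.ckpt"),   # best validation accuracy, before overfitting sets in
    (12, 0.79, "epoch_12.ckpt"),
]

# Take the checkpoint with the best validation accuracy rather than the last one;
# if that accuracy meets your suitability threshold, you are back in Scenario 1.
best_epoch, best_val_acc, best_path = max(history, key=lambda record: record[1])
print(best_epoch, best_val_acc, best_path)  # 8 0.84 epoch_8.ckpt
```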
Scenario 4: Accurate enough, but too slow
If the model is too slow (and we can't improve the hardware), then we need a faster architecture. Select another architecture to try that has a faster relative inference speed. This could be a smaller model from the same family (though choosing a smaller model in the same family risks losing accuracy) or a model from a different family (rolling the dice on accuracy—it might be comparable, worse, or better). If you want to take the model off platform (e.g., to an edge device), then you also have the option of downloading the model as an ONNX model. Depending on the model architecture, the conversion to ONNX is more or less successful. ONNX models typically run a bit faster and can (potentially) be optimized (e.g., for a GPU) using packages such as TensorRT; the speed-ups vary but can be significant (up to 10x speed-ups are advertised in the best case; we've observed typical speed-ups in the 2-4x range).
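Outside Chariot, the export-and-run workflow looks roughly like the following (a sketch assuming torch and onnxruntime are installed; an untrained resnet18 stands in for your downloaded model):

```python
import torch
from torchvision import models

# Export a classification model to ONNX. A fixed 224x224 input is used for tracing;
# the batch dimension is left dynamic so batch size can vary at inference time.
model = models.resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}},
)

# Run the exported model with onnxruntime (CPU provider shown; GPU/TensorRT
# optimization would happen downstream of this file).
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(None, {"input": dummy_input.numpy()})
```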
Dealing with disappointment
While model architecture and the learning (optimizer) settings can have a very large impact on the quality of your model, data quality and quantity will also significantly impact your model's performance. If you have iterated through the above steps a reasonable number of times and your model remains insufficiently accurate, then it is worth the effort to carefully evaluate your dataset and explore the possibility of expanding the dataset with more high-quality data. Issues related to data preparation (erroneous labels, particularities of image preprocessing, test set contamination, etc.) frequently find a way to creep into datasets. Identifying and mitigating these issues will likely improve the quality of your models.
Data: Classification Models
Below is an illustration of some of the important aspects of the classification models currently supported in Chariot (through Chariot's torchvision wrapper).
Models are grouped (as much as possible) into families of similar architecture and are ordered, roughly, in the order the models appeared in the literature.
Note 1: Relative inference time was computed using images that were 224x224 pixels (RGB), with a batch size of 8, using CPU-only inference.
Each inference time is relative to the fastest model in the catalog (a shufflenet model).
Depending on hardware availability (e.g., whether a GPU is available, and which one), the image sizes, and the batch size, the relative inference times may vary.
However, these charts should provide good heuristics.
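If you want to reproduce a comparable measurement yourself, a minimal benchmark under the same conditions (224x224 RGB, batch size 8, CPU only) might look like this sketch, which uses shufflenet_v2_x0_5 as a plausible baseline; absolute timings will differ from the numbers behind the charts:

```python
import time
import torch
from torchvision import models

def cpu_inference_time(model: torch.nn.Module, runs: int = 20) -> float:
    # Average the forward-pass time on a batch of 8 RGB 224x224 images, CPU only.
    model.eval()
    x = torch.randn(8, 3, 224, 224)
    with torch.no_grad():
        model(x)  # warm-up pass
        start = time.perf_counter()
        for _ in range(runs):
            model(x)
    return (time.perf_counter() - start) / runs

baseline = cpu_inference_time(models.shufflenet_v2_x0_5())
candidate = cpu_inference_time(models.resnet18())
print(f"relative inference time: {candidate / baseline:.2f}x the baseline model")
```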
Note 2: There are many different ways of counting layers in a neural network (specifically, what kinds of things count as a layer). We count as a layer any function in the neural network that optionally includes a bias term: specifically, convolution layers and dense (linear) layers. Additionally, as a simplification, we equate the number of layers with the depth of the model even though in some cases some of the layers are parallel to each other. This means that we occasionally somewhat overstate the depth of the network. Again, despite the occasional overstatement, the heuristics are good.
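For reference, this counting convention is easy to reproduce on the torchvision implementations (a sketch; note that it also counts auxiliary 1x1 convolutions such as ResNet's downsample layers, which is part of why the count can exceed the nominal depth in a model's name):

```python
import torch.nn as nn
from torchvision import models

def count_layers(model: nn.Module) -> int:
    # Count every convolution and dense (linear) layer, per the convention above.
    return sum(1 for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear)))

print(count_layers(models.resnet18()), count_layers(models.vgg16()))
```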
Individual early models
VGG family of models
DenseNet family of models
ResNet family of models
Wide ResNet family of models
ResNeXt family of models
RegNet family of models
**Note:** The regnet_y_128gf model is very large—currently the largest in our catalog. Checkpoints for this model will be more than 2.5 GB each. Before using this model, ensure that disk space isn't an issue, consider less frequent checkpointing, or plan to manage the checkpoints through future deletion or cleanup.
ConvNext family of models
MNAS (Mobile Neural Architecture Search) family of models
ShuffleNet family of models
MobileNet family of models
EfficientNet family of models
SqueezeNet family of models
Vision Transformer models (ViT)
Note: The vision transformer family of models has some unique constraints around the expected size and shape of input images.
The architecture creates tiles from an input image that are either 14x14, 16x16, or 32x32.
The name of the architecture includes the size of the tile; for example, vit_b_16 creates 16x16 tiles.
The tiles must evenly partition the image; therefore, the image size must be a multiple of the tile size.
That means image sizes must be a multiple of 16 for vit_b_16, a multiple of 32 for vit_b_32, or a multiple of 14 for vit_h_14.
To keep things simple, we constrain inputs to any of these models to be 224x224 (which, by design, is a multiple of 14, 16, and 32).
The user need not supply images that are already sized to 224x224; during training and inference, Chariot will automatically resize input images to this size.
Note: By default, center cropping is turned off in Chariot. However, if you opt to enable center cropping, you must ensure that the center crop size is a multiple of the tile size; otherwise, training will fail.
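The divisibility rule (for both the image size and any center crop size) is simple to check ahead of time; a small sketch:

```python
def check_vit_size(side: int, tile: int) -> None:
    # A ViT splits the (square) input into non-overlapping tile x tile patches,
    # so the side length must be an exact multiple of the tile size.
    if side % tile != 0:
        raise ValueError(f"{side} is not a multiple of the {tile}x{tile} tile size")

# 224 works for every tile size in the catalog: 224 = 16*14 = 14*16 = 7*32.
for tile in (14, 16, 32):
    check_vit_size(224, tile)

check_vit_size(200, 16)  # raises: 200 is not a multiple of the 16x16 tile size
```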
Note: The vision transformer family of models tends to be very large. For example, checkpoints from vit_l_16 or vit_l_32 will be larger than 1 GB each, and checkpoints from vit_h_14 will be larger than 2 GB each. Before using one of these models, ensure that disk space isn't an issue, consider less frequent checkpointing, or plan to manage the checkpoints through future deletion or cleanup.
Shifted-Window Transformer Models (Swin)
Data: Detection Models
Below is an illustration of some of the important aspects of the detection models currently supported in Chariot (through Chariot's torchvision wrapper).
If you inspect the names of the models carefully, you'll see they contain two familiar names: ResNet50 and MobileNetV3Large.
These detection models use a classification model (architecture) as the backbone of the detection network, and that name references the architecture of the backbone.
This explains why the number of layers is consistent (we measured the depth of the backbone).
The number of parameters (rounded to the nearest million) is slightly inconsistent because the FasterRCNN or RetinaNet framework that wraps the backbone includes some learned parameters of its own.
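You can see this effect directly in torchvision (a sketch; the exact totals depend on the torchvision version and the number of classes, so they will not match the reference tables below exactly):

```python
from torchvision.models import detection

def trainable_params(model) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Both detectors share a resnet50 backbone, but the Faster R-CNN wrapper (RPN + box head)
# and the RetinaNet wrapper (classification/regression heads) add their own parameters,
# so the totals differ even though the backbone is the same.
frcnn = detection.fasterrcnn_resnet50_fpn()
retina = detection.retinanet_resnet50_fpn()
print(trainable_params(frcnn), trainable_params(retina))
```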
Note 1: Relative inference time here is computed against the fastest detection model in the catalog. These numbers cannot be compared to the relative speeds of the classification models. In general, these models (detection models) are slower than their classification counterparts (or backbones).
Note 2: The fastest model in this collection is FasterRCNNMobileNetV3Large320FPN.
There are several reasons why this model is so fast, but the primary one is that it internally resizes images to 320x320 pixels, which is typically much smaller than the inputs to the other models.
As a result, image quality can be somewhat degraded.
If your imagery is already grainy or the objects you are trying to detect are small, then resizing the image even smaller is not advisable.
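In the underlying torchvision implementations, this internal resize is part of the model's preprocessing transform, so you can inspect it directly (a sketch; attribute names assume torchvision's GeneralizedRCNNTransform):

```python
from torchvision.models import detection

fast = detection.fasterrcnn_mobilenet_v3_large_320_fpn()
standard = detection.fasterrcnn_mobilenet_v3_large_fpn()

# The 320 variant targets a much smaller internal image size than the standard variant,
# which is the main source of its speed advantage (and of the quality trade-off noted above).
print(fast.transform.min_size, fast.transform.max_size)
print(standard.transform.min_size, standard.transform.max_size)
```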
FasterRCNN models with ResNet backbones
FasterRCNN models with MobileNet backbones
RetinaNet models with ResNet backbones
Fully Convolutional One-Stage Detection Models (FCOS)
Data: Segmentation Models
Below is an illustration of some of the important aspects of the segmentation models currently supported in Chariot (through Chariot's torchvision wrapper).
Just like the detection models, if you inspect the names of the models carefully, you'll see they contain two familiar names: resnet50 and mobilenet_v3_large.
These segmentation models use a classification model (architecture) as the backbone of the network, and that name references the architecture of the backbone.
This is again the reason for the consistency in the number of parameters and layers across models.
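As with the detection models, the backbone accounts for most of the parameters; a quick sketch using torchvision (attribute names assume torchvision's DeepLabV3 implementation):

```python
from torchvision.models import segmentation

def trainable_params(module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

model = segmentation.deeplabv3_resnet50()
# The resnet50 backbone holds the bulk of the parameters; the segmentation head
# (and auxiliary classifier, if present) contributes the remainder.
print(trainable_params(model.backbone), trainable_params(model))
```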
Note 1: Relative inference time here is computed against the fastest segmentation model in the catalog. These numbers cannot be compared to the relative speeds of the classification or detection models. In general, these models (segmentation models) are slower than their classification counterparts (or backbones).
Reference Tables
The following tables list all the models in our catalog, ordered by the number of trainable parameters in the model.
- Image Classification
- Object Detection
- Image Segmentation
Model Name | Architecture | Size | Number of Trainable Parameters | Memory Footprint |
---|---|---|---|---|
squeezenet1_1 | squeezenet1_1 (Torchvision) | Small | 1,245,506 | 4 Mb |
squeezenet1_0 | squeezenet1_0 (Torchvision) | Small | 1,258,434 | 4 Mb |
shufflenet_v2_x0_5 | shufflenet_v2_x0_5 (Torchvision) | Small | 1,376,802 | 5 Mb |
mnasnet0_5 | mnasnet0_5 (Torchvision) | Small | 2,228,522 | 8 Mb |
shufflenet_v2_x1_0 | shufflenet_v2_x1_0 (Torchvision) | Small | 2,288,614 | 8 Mb |
mobilenet_v3_small | mobilenet_v3_small (Torchvision) | Small | 2,552,866 | 9 Mb |
mnasnet0_75 | mnasnet0_75 (Torchvision) | Small | 3,180,218 | 12 Mb |
shufflenet_v2_x1_5 | shufflenet_v2_x1_5 (Torchvision) | Small | 3,513,634 | 13 Mb |
mobilenet_v2 | mobilenet_v2 (Torchvision) | Small | 3,514,882 | 13 Mb |
regnet_y_400mf | regnet_y_400mf (Torchvision) | Small | 4,354,154 | 16 Mb |
mnasnet1_0 | mnasnet1_0 (Torchvision) | Small | 4,393,322 | 16 Mb |
efficientnet_b0 | efficientnet_b0 (Torchvision) | Small | 5,298,558 | 20 Mb |
mobilenet_v3_large | mobilenet_v3_large (Torchvision) | Small | 5,493,042 | 21 Mb |
regnet_x_400mf | regnet_x_400mf (Torchvision) | Small | 5,505,986 | 21 Mb |
mnasnet1_3 | mnasnet1_3 (Torchvision) | Small | 6,292,266 | 24 Mb |
regnet_y_800mf | regnet_y_800mf (Torchvision) | Small | 6,442,522 | 24 Mb |
regnet_x_800mf | regnet_x_800mf (Torchvision) | Small | 7,269,666 | 27 Mb |
shufflenet_v2_x2_0 | shufflenet_v2_x2_0 (Torchvision) | Small | 7,404,006 | 28 Mb |
efficientnet_b1 | efficientnet_b1 (Torchvision) | Small | 7,804,194 | 30 Mb |
densenet121 | densenet121 (Torchvision) | Small | 7,988,866 | 30 Mb |
efficientnet_b2 | efficientnet_b2 (Torchvision) | Small | 9,120,004 | 35 Mb |
regnet_x_1_6gf | regnet_x_1_6gf (Torchvision) | Small | 9,200,146 | 35 Mb |
regnet_y_1_6gf | regnet_y_1_6gf (Torchvision) | Medium | 11,212,440 | 42 Mb |
resnet18 | resnet18 (Torchvision) | Medium | 11,699,522 | 44 Mb |
efficientnet_b3 | efficientnet_b3 (Torchvision) | Medium | 12,243,242 | 47 Mb |
googlenet | googlenet (Torchvision) | Medium | 13,014,898 | 49 Mb |
densenet169 | densenet169 (Torchvision) | Medium | 14,159,490 | 54 Mb |
regnet_x_3_2gf | regnet_x_3_2gf (Torchvision) | Medium | 15,306,562 | 58 Mb |
efficientnet_b4 | efficientnet_b4 (Torchvision) | Medium | 19,351,626 | 74 Mb |
regnet_y_3_2gf | regnet_y_3_2gf (Torchvision) | Medium | 19,446,348 | 74 Mb |
densenet201 | densenet201 (Torchvision) | Medium | 20,023,938 | 77 Mb |
resnet34 | resnet34 (Torchvision) | Medium | 21,807,682 | 83 Mb |
resnext50_32x4d | resnext50_32x4d (Torchvision) | Medium | 25,038,914 | 95 Mb |
resnet50 | resnet50 (Torchvision) | Medium | 25,567,042 | 97 Mb |
inception_v3 | inception_v3 (Torchvision) | Medium | 27,171,274 | 103 Mb |
swin_t | swin_t (Torchvision) | Medium | 28,298,364 | 108 Mb |
convnext_tiny | convnext_tiny (Torchvision) | Medium | 28,599,138 | 109 Mb |
densenet161 | densenet161 (Torchvision) | Medium | 28,691,010 | 110 Mb |
efficientnet_b5 | efficientnet_b5 (Torchvision) | Large | 30,399,794 | 116 Mb |
regnet_y_8gf | regnet_y_8gf (Torchvision) | Large | 39,391,482 | 150 Mb |
regnet_x_8gf | regnet_x_8gf (Torchvision) | Large | 39,582,658 | 151 Mb |
efficientnet_b6 | efficientnet_b6 (Torchvision) | Large | 43,050,714 | 165 Mb |
resnet101 | resnet101 (Torchvision) | Large | 44,559,170 | 170 Mb |
swin_s | swin_s (Torchvision) | Large | 49,616,268 | 189 Mb |
convnext_small | convnext_small (Torchvision) | Large | 50,233,698 | 191 Mb |
regnet_x_16gf | regnet_x_16gf (Torchvision) | Large | 54,288,546 | 207 Mb |
resnet152 | resnet152 (Torchvision) | Large | 60,202,818 | 230 Mb |
alexnet | alexnet (Torchvision) | Large | 61,110,850 | 233 Mb |
efficientnet_b7 | efficientnet_b7 (Torchvision) | Large | 66,357,970 | 254 Mb |
wide_resnet50_2 | wide_resnet50_2 (Torchvision) | Large | 68,893,250 | 263 Mb |
resnext101_64x4d | resnext101_64x4d (Torchvision) | Large | 83,465,282 | 319 Mb |
regnet_y_16gf | regnet_y_16gf (Torchvision) | Large | 83,600,150 | 319 Mb |
vit_b_16 | vit_b_16 (Torchvision) | Large | 86,577,666 | 330 Mb |
swin_b | swin_b (Torchvision) | Large | 87,778,234 | 335 Mb |
vit_b_32 | vit_b_32 (Torchvision) | Large | 88,234,242 | 336 Mb |
convnext_base | convnext_base (Torchvision) | Large | 88,601,474 | 337 Mb |
resnext101_32x8d | resnext101_32x8d (Torchvision) | Large | 88,801,346 | 339 Mb |
regnet_x_32gf | regnet_x_32gf (Torchvision) | Large | 107,821,570 | 411 Mb |
wide_resnet101_2 | wide_resnet101_2 (Torchvision) | Large | 126,896,706 | 484 Mb |
vgg11 | vgg11 (Torchvision) | Large | 132,873,346 | 506 Mb |
vgg11_bn | vgg11_bn (Torchvision) | Large | 132,878,850 | 506 Mb |
vgg13 | vgg13 (Torchvision) | Large | 133,057,858 | 507 Mb |
vgg13_bn | vgg13_bn (Torchvision) | Large | 133,063,746 | 507 Mb |
vgg16 | vgg16 (Torchvision) | Large | 138,367,554 | 527 Mb |
vgg16_bn | vgg16_bn (Torchvision) | Large | 138,376,002 | 527 Mb |
vgg19 | vgg19 (Torchvision) | Large | 143,677,250 | 548 Mb |
vgg19_bn | vgg19_bn (Torchvision) | Large | 143,688,258 | 548 Mb |
regnet_y_32gf | regnet_y_32gf (Torchvision) | Large | 145,056,780 | 553 Mb |
convnext_large | convnext_large (Torchvision) | Large | 197,777,346 | 754 Mb |
vit_l_16 | vit_l_16 (Torchvision) | Large | 304,336,642 | 1160 Mb |
vit_l_32 | vit_l_32 (Torchvision) | Large | 306,545,410 | 1169 Mb |
vit_h_14 | vit_h_14 (Torchvision) | Large | 632,055,810 | 2411 Mb |
regnet_y_128gf | regnet_y_128gf (Torchvision) | Large | 644,822,904 | 2461 Mb |
Model Name | Architecture | Size | Number of Trainable Parameters | Memory Footprint |
---|---|---|---|---|
YOLOv8_nano | YOLOv8 - Nano | Small | 3,012,782 | 11 Mb |
YOLOv8_small | YOLOv8 - Small | Medium | 11,139,454 | 42 Mb |
FasterRCNNMobileNetV3SmallFPN | Faster R-CNN with mobilenet_v3_small backbone | Medium | 16,823,453 | 64 Mb |
FasterRCNNMobileNetV3LargeFPN | Faster R-CNN with mobilenet_v3_large backbone | Medium | 18,970,397 | 72 Mb |
FasterRCNNMobileNetV3Large320FPN | Faster R-CNN with mobilenet_v3_large backbone | Medium | 18,970,397 | 72 Mb |
FCOSResnet18FPN | Fully Convolutional One Stage (FCOS) with resnet18 backbone | Medium | 19,106,767 | 72 Mb |
RetinaNetResnet18FPN | RetinaNet with resnet18 backbone | Medium | 19,358,526 | 73 Mb |
YOLOv8_medium | YOLOv8 - Medium | Medium | 25,862,094 | 98 Mb |
FasterRCNNResnet18FPN | Faster R-CNN with resnet18 backbone | Medium | 28,314,881 | 108 Mb |
FCOSResnet34FPN | Fully Convolutional One Stage (FCOS) with resnet34 backbone | Medium | 29,207,503 | 111 Mb |
RetinaNetResnet34FPN | RetinaNet with resnet34 backbone | Medium | 29,459,262 | 112 Mb |
FCOSResnet50FPN | Fully Convolutional One Stage (FCOS) with resnet50 backbone | Large | 32,082,895 | 122 Mb |
RetinaNetResnet50FPN | RetinaNet with resnet50 backbone | Large | 32,334,654 | 123 Mb |
FasterRCNNResnet34FPN | Faster R-CNN with resnet34 backbone | Large | 38,415,617 | 146 Mb |
FasterRCNNResnet50FPN | Faster R-CNN with resnet50 backbone | Large | 41,340,161 | 158 Mb |
YOLOv8_large | YOLOv8 - Large | Large | 43,637,534 | 166 Mb |
FCOSResnet101FPN | Fully Convolutional One Stage (FCOS) with resnet101 backbone | Large | 51,022,799 | 195 Mb |
RetinaNetResnet101FPN | RetinaNet with resnet101 backbone | Large | 51,274,558 | 196 Mb |
FasterRCNNResnet101FPN | Faster R-CNN with resnet101 backbone | Large | 60,280,065 | 230 Mb |
YOLOv8_xl | YOLOv8 - Extra large | Large | 68,162,222 | 260 Mb |
Model Name | Architecture | Size | Number of Trainable Parameters | Memory Footprint |
---|---|---|---|---|
lraspp_mobilenet_v3_large | lraspp_mobilenet_v3_large (Torchvision) | Small | 3,219,668 | 12 Mb |
deeplabv3_mobilenet_v3_large | deeplabv3_mobilenet_v3_large (Torchvision) | Medium | 11,022,650 | 42 Mb |
fcn_resnet50 | fcn_resnet50 (Torchvision) | Large | 32,951,370 | 125 Mb |
deeplabv3_resnet50 | deeplabv3_resnet50 (Torchvision) | Large | 39,636,042 | 151 Mb |
fcn_resnet101 | fcn_resnet101 (Torchvision) | Large | 51,943,498 | 198 Mb |
deeplabv3_resnet101 | deeplabv3_resnet101 (Torchvision) | Large | 58,628,170 | 224 Mb |