
Neural Network Architecture

The architecture of a neural network is a complete mathematical description of the series of transformations that data undergoes as it passes through the network. The architecture determines what a model is able to learn or represent, and it therefore has a direct impact on the accuracy of a model trained on any specific dataset. Unfortunately, it is (currently) impossible to describe what any particular architecture can learn, or to precisely determine the representational needs of most datasets (especially image- or text-based datasets). Consequently, finding an appropriate architecture ultimately requires experimentation.

The literature is full of strategies for finding or designing a suitable neural network architecture for a given dataset, a process broadly called Neural Architecture Search (NAS). These strategies vary in complexity, in how parallelizable they are, in how hardware- and time-intensive they are, and in how successful they are. In this section, we aim to provide a strategy that anyone can follow and, with some luck, use to find a sufficiently good architecture that can be trained into a suitable model.

A Simple Architecture Search Strategy

The following is a practitioner's guide to finding a suitable architecture. Before beginning, you should have a clear idea of what suitable means for your application. Generally, suitable means satisfying some constraints: What is the minimum accuracy required to be useful? How fast does the model need to be on my hardware? Each experiment (training run) you do with an architecture provides data points that inform your future experiments. Conveniently, Chariot tracks all your training runs, so you can return at any time to completed runs to review or compare them.

Step 1: Just pick an architecture to start.

Pick a starting architecture, and try training the model. Chariot provides good defaults for a first choice of model. For classification tasks, the default choice is ResNet18, which is the smallest model in the ResNet family. It is generally recommended to start small: training will be faster, letting you collect your first data point(s) sooner.
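As an illustration of how small this starting point is, here is a minimal sketch of instantiating ResNet18 with torchvision directly (outside of Chariot, which handles this setup for you on-platform); the pretrained-weights tag and the 10-class dataset are illustrative assumptions.

```python
# A minimal sketch (plain torchvision) of starting from a small architecture.
import torch
import torchvision

# Start small: ResNet18 pretrained on ImageNet, then swap the final
# classification layer to match a hypothetical 10-class dataset.
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.fc = torch.nn.Linear(model.fc.in_features, 10)

# Quick sanity check on a dummy batch of 224x224 RGB images.
with torch.no_grad():
    out = model(torch.randn(8, 3, 224, 224))
print(out.shape)  # torch.Size([8, 10])
```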

Step 2: Evaluate the results.

At each step, it is important to evaluate the suitability of the model and use that information to guide further iterations. Training run metrics are computed and stored periodically during the run—users can specify the frequency at which these metrics are produced. Metrics are available live during the run (for monitoring) on the Training Run > Metrics page. After a successful training run and evaluation of the model, you will most likely be in one of the following scenarios (determine which, and follow the recommendation).

Scenario 1: Suitable

In this scenario, the model meets or exceeds all your suitability requirements. You can stop here and keep the model you have now trained. Or, you could try to find a better model. If better means faster, you could experiment with architectures whose relative inference time is lower than this model's: generally, either smaller models in the same family or models from another family. If better means more accurate, you could try training the model longer, adding data (and training longer), or training a larger model. If you choose to train a larger model, try to stay in the same family and incrementally increase the size/depth of the model, e.g., substitute a ResNet34 model for a ResNet18 model (instead of jumping from ResNet18 all the way to a ResNet152 model). Then, train and reevaluate.
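When better means more accurate and you decide to step up within a family, a quick parameter-count comparison can help you pick the next increment. The following is a rough sketch using plain torchvision, outside of Chariot.

```python
# Compare parameter counts across a family before committing to a longer run.
import torchvision

for name in ["resnet18", "resnet34", "resnet50"]:
    # getattr fetches the constructor, e.g., torchvision.models.resnet34
    model = getattr(torchvision.models, name)(weights=None)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name}: {n_params / 1e6:.1f}M trainable parameters")
```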

Scenario 2: Underfit

In this case, your model failed to learn at all, or it simply didn't learn enough (the accuracy is too low). There are several options to consider, any of which may produce a better model. The first is to retry the training with the same architecture but different optimizer settings.

You might adjust the learning rate. If your batch size was smaller than the default (32), you probably want to decrease the default learning rate (from 0.001 to maybe 0.0005). Likewise, if the batch size is very large (say, for example, larger than 256 or 512), then you may need to increase the learning rate from the default. If the training metrics look very noisy (frequent large variation in the training loss), this can indicate that your model is being adjusted too rapidly. In that case, reducing the learning rate is also likely to be helpful. If the metrics are nearly constant, then that could indicate that your model is being adjusted too slowly. In that case, increasing the learning rate may be helpful.
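As a rule of thumb, the adjustments above amount to scaling the learning rate roughly in proportion to the batch size. The sketch below captures that heuristic, assuming the defaults mentioned above (learning rate 0.001 at batch size 32); it is a starting point for experimentation, not a guarantee.

```python
# A hedged heuristic: scale the learning rate roughly linearly with batch
# size, relative to the defaults discussed above (lr=0.001 at batch size 32).
def suggested_lr(batch_size: int, base_lr: float = 0.001, base_batch_size: int = 32) -> float:
    return base_lr * (batch_size / base_batch_size)

print(suggested_lr(16))   # 0.0005 -> smaller batch, smaller learning rate
print(suggested_lr(512))  # 0.016  -> much larger batch, larger learning rate
```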

If adjustments to the learning rate don't yield better results, then consider other optimizer settings as well. You might reduce the weight decay (if you had any), or try a different optimizer altogether (Adam is a good default, but in some situations others may work better). You might even retry the training run (briefly) with identical settings, especially if this training run was "from scratch" and not starting from any pre-trained weights. If you've tried some of the above with no success, then try a larger model (either more parameters or more layers). Best practice is to incrementally increase the size of the model within the same family before experimenting with models from other families. If your model is also too slow for your suitability requirements, then choosing a larger model from the same family is not advised (as this will be even slower).
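For reference, here is a minimal sketch of these optimizer knobs in plain PyTorch; the specific values are illustrative assumptions, not Chariot defaults.

```python
# A minimal sketch of the optimizer settings discussed above, in plain PyTorch.
import torch
import torchvision

model = torchvision.models.resnet18(weights=None)

# Adam is a good default; weight_decay=0.0 removes L2 regularization entirely.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)

# If Adam stalls, SGD with momentum is one common alternative to try:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
```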

Scenario 3: Overfit

In this case, your model performs well (very small loss) on the training data but poorly on the test/validation data. If at some point during the training process the model was good on both the training and test/validation data, then you could simply use the model from an earlier checkpoint (less training, before it overfit) if that accuracy is suitable. In that case, the earlier checkpoint puts you into Scenario 1. If the validation accuracy was never good, then your model overfit to the training data too fast. The two best options to consider here are either using a smaller architecture (fewer parameters or fewer layers), again first from the same family but alternatively from another family, or increasing the weight decay in the training run settings. If the model is also too slow, then opt first for a smaller architecture over increasing the weight decay, as the latter will have no noticeable impact on speed.
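The "use an earlier checkpoint" option amounts to keeping whichever epoch's weights scored best on the validation data. Here is a minimal sketch of that idea in plain PyTorch; train_one_epoch and evaluate are hypothetical placeholders for whatever training and evaluation loop you already have, and Chariot's periodic checkpoints give you the same ability without writing this yourself.

```python
# A minimal sketch of keeping the best (pre-overfitting) checkpoint.
# train_one_epoch and evaluate are hypothetical placeholders for your own loop.
import copy


def train_with_best_checkpoint(model, train_loader, val_loader, optimizer,
                               train_one_epoch, evaluate, num_epochs=20):
    best_acc = 0.0
    best_state = copy.deepcopy(model.state_dict())
    for _ in range(num_epochs):
        train_one_epoch(model, train_loader, optimizer)
        val_acc = evaluate(model, val_loader)
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())
    # Restore the weights from the epoch with the best validation accuracy.
    model.load_state_dict(best_state)
    return best_acc
```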

Scenario 4: Accurate enough, but too slow

If the model is too slow (and we can't improve the hardware), then we need a faster architecture. Select another architecture to try that has a faster relative inference speed. This could be a smaller model from the same family (though choosing a smaller model in the same family risks losing accuracy) or a model from a different family (rolling the dice on accuracy: it might be comparable, worse, or better). If you want to take the model off platform (e.g., to an edge device), then you also have the option of downloading the model as an ONNX model. Depending on the model architecture, the conversion to ONNX is more or less successful. ONNX models typically run a bit faster and can (potentially) be optimized (e.g., for a GPU) using packages such as TensorRT; the speed-ups vary but can be significant (up to 10x speed-ups are advertised in the best case; we've observed typical speed-ups in the 2-4x range).
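If you do export a model yourself, the step looks roughly like the sketch below (plain PyTorch/torchvision; the model choice, file name, and input size are illustrative assumptions). Chariot's ONNX download gives you such a file directly.

```python
# A minimal sketch of exporting a torchvision model to ONNX with torch.onnx.export.
import torch
import torchvision

model = torchvision.models.resnet18(weights=None).eval()
dummy_input = torch.randn(1, 3, 224, 224)  # example input used to trace the model

torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},  # allow variable batch size
)
```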

Dealing with disappointment

While model architecture and the learning (optimizer) settings can have a very large impact on the quality of your model, data quality and quantity will also significantly impact your model's performance. If you have iterated through the above steps a reasonable number of times and your model remains insufficiently accurate, then it is worth the effort to carefully evaluate your dataset and explore the possibility of expanding the dataset with more high-quality data. Issues related to data preparation (erroneous labels, particularities of image preprocessing, test set contamination, etc.) frequently find a way to creep into datasets. Identifying and mitigating these issues will likely improve the quality of your models.

Data: Classification Models

Below is an illustration of some of the important aspects of the classification models currently supported in Chariot (through Chariot's torchvision wrapper). Models are grouped (as much as possible) into families of similar architecture and are ordered roughly by when the models appeared in the literature.

Note 1: Relative inference time was computed using 224x224 pixel (RGB) images, with a batch size of 8, using CPU-only inference. Each inference time is relative to the fastest model in the catalog (a ShuffleNet model). Depending on hardware availability (e.g., whether a GPU is available, and which one), the image sizes, and the batch size, the relative inference times may vary. However, these charts should provide good heuristics.
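For intuition, relative inference times like these can be reproduced with a simple timing loop along the following lines; this is a rough sketch, the model list and repeat count are illustrative assumptions, and your absolute numbers will depend on your hardware.

```python
# A rough sketch of measuring relative CPU inference time: time a fixed batch
# through each model and divide by the fastest.
import time
import torch
import torchvision

models = {
    "shufflenet_v2_x0_5": torchvision.models.shufflenet_v2_x0_5(weights=None),
    "resnet18": torchvision.models.resnet18(weights=None),
}
batch = torch.randn(8, 3, 224, 224)  # batch size 8, 224x224 RGB, CPU

timings = {}
with torch.no_grad():
    for name, model in models.items():
        model.eval()
        start = time.perf_counter()
        for _ in range(10):
            model(batch)
        timings[name] = (time.perf_counter() - start) / 10

fastest = min(timings.values())
for name, t in timings.items():
    print(f"{name}: {t / fastest:.2f}x relative inference time")
```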

Note 2: There are many different ways of counting layers in a neural network (specifically, what kinds of things count as a layer). We count as a layer any function in the neural network that optionally includes a bias term: specifically, convolution layers and dense (linear) layers. Additionally, as a simplification, we equate the number of layers with the depth of the model, even though in some cases some of the layers are parallel to each other. This means that we occasionally somewhat overstate the depth of the network. Again, despite the occasional overstatement, the heuristics are good.
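Concretely, this layer-counting convention can be expressed in a few lines of PyTorch (a minimal sketch): count every convolution and dense (linear) module in the network.

```python
# A minimal sketch of the layer-counting convention described above.
import torch.nn as nn
import torchvision

model = torchvision.models.resnet18(weights=None)
layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
# For resnet18 this is slightly more than the nominal 18 because the shortcut
# (downsample) convolutions are counted too -- the overstatement noted above.
print(len(layers))
```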

Individual early models

bar plots of stats for individual early models

VGG family of models

bar plots of stats for VGG models

DenseNet family of models

bar plots of stats for densenet models

ResNet family of models

bar plots of stats for Resnet models

Wide ResNet family of models

bar plots of stats for Wide Resnet models

ResNeXt family of models

bar plots of stats for Resnext models

RegNet family of models

bar plots of stats for RegNet models


Note: The regnet_y_128gf model is very large; it is currently the largest in our catalog. Checkpoints for this model will be more than 2.5 GB each. Before using this model, ensure that disk space isn't an issue, consider less frequent checkpointing, or plan to manage the checkpoints through future deletion or cleanup.
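The checkpoint size follows directly from the parameter count; here is a back-of-the-envelope sketch, assuming 4-byte float32 weights (any optimizer state stored in a checkpoint adds more on top).

```python
# Back-of-the-envelope checkpoint size for regnet_y_128gf (float32 weights only).
n_params = 644_822_904            # parameter count from the reference table below
weights_gb = n_params * 4 / 1e9   # 4 bytes per float32 parameter
print(f"{weights_gb:.2f} GB")     # ~2.58 GB before any optimizer state
```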

ConvNext family of models

bar plots of stats for convnext models

MNAS (Mobile Neural Architecture Search) family of models

bar plots of stats for MNAS models

ShuffleNet family of models

bar plots of stats for shufflenet models

MobileNet family of models

bar plots of stats for Mobilenet models

EfficientNet family of models

bar plots of stats for efficient net models

SqueezeNet family of models

bar plots of stats for SqueezeNet models

Vision Transformer models (ViT)

bar plots of stats for vision transformer models

Note: The vision transformer family of models has some unique constraints on the expected size and shape of input images. The architecture creates tiles from an input image that are either 14x14, 16x16, or 32x32. The name of the architecture includes the size of the tile; for example, vit_b_16 creates 16x16 tiles. The tiles must evenly partition the image; therefore, the image size must be a multiple of the tile size. That means image sizes must be a multiple of 16 for vit_b_16, a multiple of 32 for vit_b_32, or a multiple of 14 for vit_h_14. To simplify, we constrain images for all of these models to be 224x224 (which, by design, is a multiple of 14, 16, and 32). The user need not supply images that are already sized to 224x224; during training and inference, Chariot will automatically resize input images to this size.
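A minimal sketch of this constraint, using vit_b_16 from torchvision directly (the weights setting and dummy input are illustrative assumptions):

```python
# The image side length must be divisible by the tile (patch) size; 224 is a
# multiple of 14, 16, and 32, which is why inputs are resized to 224x224.
import torch
import torchvision

patch_size = 16   # vit_b_16 creates 16x16 tiles
image_size = 224
assert image_size % patch_size == 0, "image size must be a multiple of the tile size"

model = torchvision.models.vit_b_16(weights=None).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, image_size, image_size))
print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head
```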

Note: By default, center cropping is turned off in Chariot. However, if the user opts to enable center cropping, the user must ensure that the center crop size is a multiple of the tile size; otherwise, training will fail.

Note: The vision transformer family of models tends to be very large. For example, checkpoints from vit_l_16 or vit_l_32 will be larger than 1 GB each; checkpoints from vit_h_14 will be larger than 2 GB each. Before using one of these models, ensure that disk space isn't an issue, consider less frequent checkpointing, or plan to manage the checkpoints through future deletion or cleanup.

Shifted-Window Transformer Models (SWin)

bar plots of stats for shifted window transformer models

Data: Detection Models

Below is an illustration of some of the important aspects of the detection models currently supported in Chariot (through Chariot's torchvision wrapper). If you inspect the names of the models carefully, you'll see they contain two familiar names: ResNet50 and MobileNetV3Large. These detection models use a classification model (architecture) as the backbone of the detection network, and that name references the architecture of the backbone. This explains why the number of layers is consistent (we measured the depth of the backbone). The number of parameters (rounded to the nearest million) is slightly inconsistent because the FasterRCNN or RetinaNet framework that wraps the backbone includes some additional learned parameters of its own.
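To see the backbone relationship concretely, here is a minimal sketch using torchvision directly; the num_classes value is an illustrative assumption, not a Chariot default.

```python
# A minimal sketch of a detector wrapping a classification backbone.
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# The ResNet50 backbone supplies most of the layers; the FasterRCNN head and
# FPN add the (comparatively few) extra learned parameters mentioned above.
model = fasterrcnn_resnet50_fpn(weights=None, num_classes=91)
print(type(model.backbone).__name__)  # BackboneWithFPN wrapping a ResNet50
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")
```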

Note 1: Relative inference time here is computed against the fastest detection model in the catalog. These numbers cannot be compared to the relative speeds of the classification models. In general, these models (detection models) are slower than their classification counterparts (or backbones).

Note 2: The fastest model in this collection is FasterRCNNMobileNetV3Large320FPN. There are several reasons why this model is so fast, but the primary reason it is so much faster than the other models is that it internally resizes images to 320x320 pixels, which is typically much smaller than the inputs to the other models. As a result, image quality can be somewhat degraded. If your imagery is already grainy or the objects you are trying to detect are small, then resizing the image even smaller is not advisable.

FasterRCNN models with ResNet backbones

bar plots of stats for torchvision Detection models

FasterRCNN models with MobileNet backbones

bar plots of stats for torchvision Detection models

RetinaNet models with ResNet backbones

bar plots of stats for torchvision Detection models

Fully Convolutional One-Stage Detection Models (FCOS)

bar plots of stats for torchvision FCOS models

Data: Segmentation Models

Below is an illustration of some of the important aspects of the segmentation models currently supported in Chariot (through Chariot's torchvision wrapper). Just like the detection models, if you inspect the names of the models carefully, you'll see they contain two familiar names: resnet50 and mobilenet_v3_large. These segmentation models use a classification model (architecture) as the backbone of the network, and that name references the architecture of the backbone. This is again the reason for the consistency in the number of parameters and layers across models.
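As with the detectors, the backbone relationship can be seen directly in torchvision. The sketch below uses deeplabv3_resnet50 as one example of a segmentation architecture built on a resnet50 backbone; the exact models in the Chariot catalog may differ, and num_classes is an illustrative assumption.

```python
# A minimal sketch of a segmentation model built on a resnet50 backbone.
import torch
from torchvision.models.segmentation import deeplabv3_resnet50

model = deeplabv3_resnet50(weights=None, num_classes=21).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))["out"]
print(out.shape)  # torch.Size([1, 21, 224, 224]): a per-pixel class map
```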

Note 1: Relative inference time here is computed against the fastest segmentation model in the catalog. These numbers cannot be compared to the relative speeds of the classification or detection models. In general, these models (segmentation models) are slower than their classification counterparts (or backbones).

bar plot of stats for torchvision Segmentation models

Reference Tables

The following table lists all the models in our catalog ordered by the number of parameters in the model.

| Model Name | Architecture | Size | Number of Trainable Parameters | Memory Footprint |
| --- | --- | --- | --- | --- |
| squeezenet1_1 | squeezenet1_1 (Torchvision) | Small | 1,245,506 | 4 MB |
| squeezenet1_0 | squeezenet1_0 (Torchvision) | Small | 1,258,434 | 4 MB |
| shufflenet_v2_x0_5 | shufflenet_v2_x0_5 (Torchvision) | Small | 1,376,802 | 5 MB |
| mnasnet0_5 | mnasnet0_5 (Torchvision) | Small | 2,228,522 | 8 MB |
| shufflenet_v2_x1_0 | shufflenet_v2_x1_0 (Torchvision) | Small | 2,288,614 | 8 MB |
| mobilenet_v3_small | mobilenet_v3_small (Torchvision) | Small | 2,552,866 | 9 MB |
| mnasnet0_75 | mnasnet0_75 (Torchvision) | Small | 3,180,218 | 12 MB |
| shufflenet_v2_x1_5 | shufflenet_v2_x1_5 (Torchvision) | Small | 3,513,634 | 13 MB |
| mobilenet_v2 | mobilenet_v2 (Torchvision) | Small | 3,514,882 | 13 MB |
| regnet_y_400mf | regnet_y_400mf (Torchvision) | Small | 4,354,154 | 16 MB |
| mnasnet1_0 | mnasnet1_0 (Torchvision) | Small | 4,393,322 | 16 MB |
| efficientnet_b0 | efficientnet_b0 (Torchvision) | Small | 5,298,558 | 20 MB |
| mobilenet_v3_large | mobilenet_v3_large (Torchvision) | Small | 5,493,042 | 21 MB |
| regnet_x_400mf | regnet_x_400mf (Torchvision) | Small | 5,505,986 | 21 MB |
| mnasnet1_3 | mnasnet1_3 (Torchvision) | Small | 6,292,266 | 24 MB |
| regnet_y_800mf | regnet_y_800mf (Torchvision) | Small | 6,442,522 | 24 MB |
| regnet_x_800mf | regnet_x_800mf (Torchvision) | Small | 7,269,666 | 27 MB |
| shufflenet_v2_x2_0 | shufflenet_v2_x2_0 (Torchvision) | Small | 7,404,006 | 28 MB |
| efficientnet_b1 | efficientnet_b1 (Torchvision) | Small | 7,804,194 | 30 MB |
| densenet121 | densenet121 (Torchvision) | Small | 7,988,866 | 30 MB |
| efficientnet_b2 | efficientnet_b2 (Torchvision) | Small | 9,120,004 | 35 MB |
| regnet_x_1_6gf | regnet_x_1_6gf (Torchvision) | Small | 9,200,146 | 35 MB |
| regnet_y_1_6gf | regnet_y_1_6gf (Torchvision) | Medium | 11,212,440 | 42 MB |
| resnet18 | resnet18 (Torchvision) | Medium | 11,699,522 | 44 MB |
| efficientnet_b3 | efficientnet_b3 (Torchvision) | Medium | 12,243,242 | 47 MB |
| googlenet | googlenet (Torchvision) | Medium | 13,014,898 | 49 MB |
| densenet169 | densenet169 (Torchvision) | Medium | 14,159,490 | 54 MB |
| regnet_x_3_2gf | regnet_x_3_2gf (Torchvision) | Medium | 15,306,562 | 58 MB |
| efficientnet_b4 | efficientnet_b4 (Torchvision) | Medium | 19,351,626 | 74 MB |
| regnet_y_3_2gf | regnet_y_3_2gf (Torchvision) | Medium | 19,446,348 | 74 MB |
| densenet201 | densenet201 (Torchvision) | Medium | 20,023,938 | 77 MB |
| resnet34 | resnet34 (Torchvision) | Medium | 21,807,682 | 83 MB |
| resnext50_32x4d | resnext50_32x4d (Torchvision) | Medium | 25,038,914 | 95 MB |
| resnet50 | resnet50 (Torchvision) | Medium | 25,567,042 | 97 MB |
| inception_v3 | inception_v3 (Torchvision) | Medium | 27,171,274 | 103 MB |
| swin_t | swin_t (Torchvision) | Medium | 28,298,364 | 108 MB |
| convnext_tiny | convnext_tiny (Torchvision) | Medium | 28,599,138 | 109 MB |
| densenet161 | densenet161 (Torchvision) | Medium | 28,691,010 | 110 MB |
| efficientnet_b5 | efficientnet_b5 (Torchvision) | Large | 30,399,794 | 116 MB |
| regnet_y_8gf | regnet_y_8gf (Torchvision) | Large | 39,391,482 | 150 MB |
| regnet_x_8gf | regnet_x_8gf (Torchvision) | Large | 39,582,658 | 151 MB |
| efficientnet_b6 | efficientnet_b6 (Torchvision) | Large | 43,050,714 | 165 MB |
| resnet101 | resnet101 (Torchvision) | Large | 44,559,170 | 170 MB |
| swin_s | swin_s (Torchvision) | Large | 49,616,268 | 189 MB |
| convnext_small | convnext_small (Torchvision) | Large | 50,233,698 | 191 MB |
| regnet_x_16gf | regnet_x_16gf (Torchvision) | Large | 54,288,546 | 207 MB |
| resnet152 | resnet152 (Torchvision) | Large | 60,202,818 | 230 MB |
| alexnet | alexnet (Torchvision) | Large | 61,110,850 | 233 MB |
| efficientnet_b7 | efficientnet_b7 (Torchvision) | Large | 66,357,970 | 254 MB |
| wide_resnet50_2 | wide_resnet50_2 (Torchvision) | Large | 68,893,250 | 263 MB |
| resnext101_64x4d | resnext101_64x4d (Torchvision) | Large | 83,465,282 | 319 MB |
| regnet_y_16gf | regnet_y_16gf (Torchvision) | Large | 83,600,150 | 319 MB |
| vit_b_16 | vit_b_16 (Torchvision) | Large | 86,577,666 | 330 MB |
| swin_b | swin_b (Torchvision) | Large | 87,778,234 | 335 MB |
| vit_b_32 | vit_b_32 (Torchvision) | Large | 88,234,242 | 336 MB |
| convnext_base | convnext_base (Torchvision) | Large | 88,601,474 | 337 MB |
| resnext101_32x8d | resnext101_32x8d (Torchvision) | Large | 88,801,346 | 339 MB |
| regnet_x_32gf | regnet_x_32gf (Torchvision) | Large | 107,821,570 | 411 MB |
| wide_resnet101_2 | wide_resnet101_2 (Torchvision) | Large | 126,896,706 | 484 MB |
| vgg11 | vgg11 (Torchvision) | Large | 132,873,346 | 506 MB |
| vgg11_bn | vgg11_bn (Torchvision) | Large | 132,878,850 | 506 MB |
| vgg13 | vgg13 (Torchvision) | Large | 133,057,858 | 507 MB |
| vgg13_bn | vgg13_bn (Torchvision) | Large | 133,063,746 | 507 MB |
| vgg16 | vgg16 (Torchvision) | Large | 138,367,554 | 527 MB |
| vgg16_bn | vgg16_bn (Torchvision) | Large | 138,376,002 | 527 MB |
| vgg19 | vgg19 (Torchvision) | Large | 143,677,250 | 548 MB |
| vgg19_bn | vgg19_bn (Torchvision) | Large | 143,688,258 | 548 MB |
| regnet_y_32gf | regnet_y_32gf (Torchvision) | Large | 145,056,780 | 553 MB |
| convnext_large | convnext_large (Torchvision) | Large | 197,777,346 | 754 MB |
| vit_l_16 | vit_l_16 (Torchvision) | Large | 304,336,642 | 1160 MB |
| vit_l_32 | vit_l_32 (Torchvision) | Large | 306,545,410 | 1169 MB |
| vit_h_14 | vit_h_14 (Torchvision) | Large | 632,055,810 | 2411 MB |
| regnet_y_128gf | regnet_y_128gf (Torchvision) | Large | 644,822,904 | 2461 MB |