NNVM Intermediate Representation
NNVM is a reusable graph Intermediate Representation (IR) stack for deep learning systems. It provides APIs to construct, represent, and transform computation graphs, covering most of the high-level optimizations needed in deep learning. NNVM is part of the TVM stack and provides a shared compiler that lets deep learning frameworks optimize, compile, and deploy models to different hardware backends through TVM.
Key requirements
Given the goals of the TVM stack, the key requirements of NNVM are:
- Keep dependencies in the deployment module to a minimum
- Allow new operators to be added to the IR in a decentralized fashion
- Allow new optimization passes to be added to the IR and applied to existing graphs
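The last requirement can be illustrated with a minimal Python sketch. It is a toy model written for this article, not NNVM's API: the names `PASSES`, `register_pass`, `apply_passes`, and the `RemoveIdentity` pass are all hypothetical, and the graph is simplified to a topologically ordered list of tuples. The point it tries to show is that a pass registered after a graph was built can still be applied to that graph by name.

```python
from typing import Callable, Dict, List, Tuple

# Toy graph: a topologically ordered list of (node_name, op_name, input_names).
# This is an assumption made for illustration, not NNVM's data structure.
Graph = List[Tuple[str, str, List[str]]]

# Hypothetical pass registry: pass name -> function from graph to graph.
PASSES: Dict[str, Callable[[Graph], Graph]] = {}


def register_pass(name: str):
    """Decorator that records a pass under the given name."""
    def wrap(fn: Callable[[Graph], Graph]) -> Callable[[Graph], Graph]:
        PASSES[name] = fn
        return fn
    return wrap


def apply_passes(graph: Graph, names: List[str]) -> Graph:
    """Run the named passes in order and return the transformed graph."""
    for name in names:
        graph = PASSES[name](graph)
    return graph


# The graph exists before the pass below is registered.
graph: Graph = [
    ("x", "null", []),
    ("a", "identity", ["x"]),
    ("b", "relu", ["a"]),
]


@register_pass("RemoveIdentity")
def remove_identity(g: Graph) -> Graph:
    """Drop identity nodes and rewire their users to the identity's input."""
    alias: Dict[str, str] = {}
    kept: Graph = []
    for name, op, inputs in g:
        inputs = [alias.get(i, i) for i in inputs]
        if op == "identity":
            alias[name] = inputs[0]
        else:
            kept.append((name, op, inputs))
    return kept


optimized = apply_passes(graph, ["RemoveIdentity"])
print(optimized)
```

Because passes are looked up by name at call time, code that builds graphs does not need to know which passes exist, which is the kind of decoupling the requirement describes.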
Key elements
The key elements of NNVM are:
- An operator registry system for registering and adding new operators
- An operator attribute system that provides operator properties in a decentralized fashion
- A reusable IR data structure for optimization passes
This design allows the NNVM compiler to be used directly as the optimization and compilation stack for frameworks. Because NNVM is extensible, new adjustments are easy to make without constraining backend providers.
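To make the three elements more concrete, here is a small Python sketch of how an operator registry, a decentralized attribute table, and a reusable graph structure could fit together. Everything in it is an assumption made for illustration: `OP_REGISTRY`, `register_op`, `set_op_attr`, `Node`, and `infer_shapes` are hypothetical names, not NNVM functions, and the shape rules are deliberately trivial.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

# Hypothetical registry: operator name -> attribute table.
OP_REGISTRY: Dict[str, Dict[str, object]] = {}


def register_op(name: str) -> None:
    """Add an operator to the registry with an empty attribute table."""
    OP_REGISTRY.setdefault(name, {})


def set_op_attr(name: str, key: str, value: object) -> None:
    """Attach an attribute to a registered operator, possibly from another module."""
    OP_REGISTRY[name][key] = value


@dataclass
class Node:
    op: str                      # operator name, or "null" for a variable
    name: str
    inputs: List["Node"] = field(default_factory=list)


def topo_order(output: Node) -> List[Node]:
    """Return the nodes reachable from output, inputs before users."""
    seen, order = set(), []

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    visit(output)
    return order


def infer_shapes(output: Node, var_shapes: Dict[str, Shape]) -> Dict[str, Shape]:
    """A pass that only relies on the 'infer_shape' attribute of each operator."""
    shapes: Dict[str, Shape] = {}
    for node in topo_order(output):
        if node.op == "null":
            shapes[node.name] = var_shapes[node.name]
        else:
            infer = OP_REGISTRY[node.op]["infer_shape"]
            shapes[node.name] = infer([shapes[i.name] for i in node.inputs])
    return shapes


# Operators and their attributes are registered independently of the pass.
register_op("relu")
register_op("elemwise_add")
set_op_attr("relu", "infer_shape", lambda ins: ins[0])
set_op_attr("elemwise_add", "infer_shape", lambda ins: ins[0])

x, y = Node("null", "x"), Node("null", "y")
out = Node("relu", "out", [Node("elemwise_add", "sum", [x, y])])
shapes = infer_shapes(out, {"x": (1, 8), "y": (1, 8)})
print(shapes)
```

In this sketch the pass never names a specific operator. It only asks the registry for an attribute, so a new operator becomes usable by the pass as soon as someone registers that attribute for it.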
Example of NNVM Intermediate Representation
This is what a ResNet18 model looks like in NNVM IR (graph representation):
Graph(%data,
%batch_norm0_gamma_mul_div_expand,
%batch_norm0_add_beta_expand,
%resnetv22_conv0_weight_OIHW3i8o,
%batch_norm1_gamma_mul_div_expand,
%batch_norm1_add_beta_expand,
%batch_norm2_gamma_mul_div_expand,
%batch_norm2_add_beta_expand,
%resnetv22_stage1_conv0_weight_OIHW8i8o,
%batch_norm3_gamma_mul_div_expand,
%batch_norm3_add_beta_expand,
%resnetv22_stage1_conv1_weight_OIHW8i8o,
%batch_norm4_gamma_mul_div_expand,
%batch_norm4_add_beta_expand,
%resnetv22_stage1_conv2_weight_OIHW8i8o,
%batch_norm5_gamma_mul_div_expand,
%batch_norm5_add_beta_expand,
%resnetv22_stage1_conv3_weight_OIHW8i8o,
%batch_norm6_gamma_mul_div_expand,
%batch_norm6_add_beta_expand,
%resnetv22_stage2_conv0_weight_OIHW8i8o,
%batch_norm7_gamma_mul_div_expand,
%batch_norm7_add_beta_expand,
%resnetv22_stage2_conv1_weight_OIHW8i8o,
%resnetv22_stage2_conv2_weight_OI8i8oHW,
%batch_norm8_gamma_mul_div_expand,
%batch_norm8_add_beta_expand,
%resnetv22_stage2_conv3_weight_OIHW8i8o,
%batch_norm9_gamma_mul_div_expand,
%batch_norm9_add_beta_expand,
%resnetv22_stage2_conv4_weight_OIHW8i8o,
%batch_norm10_gamma_mul_div_expand,
%batch_norm10_add_beta_expand,
%resnetv22_stage3_conv0_weight_OIHW8i8o,
%batch_norm11_gamma_mul_div_expand,
%batch_norm11_add_beta_expand,
%resnetv22_stage3_conv1_weight_OIHW8i8o,
%resnetv22_stage3_conv2_weight_OI8i8oHW,
%batch_norm12_gamma_mul_div_expand,
%batch_norm12_add_beta_expand,
%resnetv22_stage3_conv3_weight_OIHW8i8o,
%batch_norm13_gamma_mul_div_expand,
%batch_norm13_add_beta_expand,
%resnetv22_stage3_conv4_weight_OIHW8i8o,
%batch_norm14_gamma_mul_div_expand,
%batch_norm14_add_beta_expand,
%resnetv22_stage4_conv0_weight_OIHW8i8o,
%batch_norm15_gamma_mul_div_expand,
%batch_norm15_add_beta_expand,
%resnetv22_stage4_conv1_weight_OIHW8i8o,
%resnetv22_stage4_conv2_weight_OI8i8oHW,
%batch_norm16_gamma_mul_div_expand,
%batch_norm16_add_beta_expand,
%resnetv22_stage4_conv3_weight_OIHW8i8o,
%batch_norm17_gamma_mul_div_expand,
%batch_norm17_add_beta_expand,
%resnetv22_stage4_conv4_weight_OIHW8i8o,
%batch_norm18_gamma_mul_div_expand,
%batch_norm18_add_beta_expand,
%resnetv22_dense0_weight,
%__mul_scalar__1) {
%3 = tvm_op(%data, %batch_norm0_gamma_mul_div_expand,
%batch_norm0_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add___layout_transform__',
flatten_data='0')
%7 = tvm_op(%3, %resnetv22_conv0_weight_OIHW3i8o,
%batch_norm1_gamma_mul_div_expand, %batch_norm1_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu',
flatten_data='0')
%8 = tvm_op(%7, num_outputs='1', num_inputs='1',
func_name='fuse_max_pool2d', flatten_data='0')
%11 = tvm_op(%8, %batch_norm2_gamma_mul_div_expand,
%batch_norm2_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu', flatten_data='0')
%15 = tvm_op(%11, %resnetv22_stage1_conv0_weight_OIHW8i8o,
%batch_norm3_gamma_mul_div_expand, %batch_norm3_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_1',
flatten_data='0')
%17 = tvm_op(%15, %resnetv22_stage1_conv1_weight_OIHW8i8o, %8,
num_outputs='1', num_inputs='3',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add', flatten_data='0')
%20 = tvm_op(%17, %batch_norm4_gamma_mul_div_expand, %batch_norm4_add_beta_expand,
num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu', flatten_data='0')
%24 = tvm_op(%20, %resnetv22_stage1_conv2_weight_OIHW8i8o,
%batch_norm5_gamma_mul_div_expand, %batch_norm5_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_1',
flatten_data='0')
%28 = tvm_op(%24, %resnetv22_stage1_conv3_weight_OIHW8i8o, %17,
%batch_norm6_gamma_mul_div_expand, %batch_norm6_add_beta_expand,
num_outputs='1', num_inputs='5',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu',
flatten_data='0')
%32 = tvm_op(%28, %resnetv22_stage2_conv0_weight_OIHW8i8o,
%batch_norm7_gamma_mul_div_expand, %batch_norm7_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_2',
flatten_data='0')
%35 = tvm_op(%28, %resnetv22_stage2_conv2_weight_OI8i8oHW, num_outputs='1',
num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc', flatten_data='0')
%36 = tvm_op(%32, %resnetv22_stage2_conv1_weight_OIHW8i8o, %35,
num_outputs='1', num_inputs='3',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_1', flatten_data='0')
%39 = tvm_op(%36, %batch_norm8_gamma_mul_div_expand,
%batch_norm8_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu_1', flatten_data='0')
%43 = tvm_op(%39, %resnetv22_stage2_conv3_weight_OIHW8i8o,
%batch_norm9_gamma_mul_div_expand, %batch_norm9_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_3',
flatten_data='0')
%47 = tvm_op(%43, %resnetv22_stage2_conv4_weight_OIHW8i8o, %36,
%batch_norm10_gamma_mul_div_expand, %batch_norm10_add_beta_expand,
num_outputs='1', num_inputs='5',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_1',
flatten_data='0')
%51 = tvm_op(%47, %resnetv22_stage3_conv0_weight_OIHW8i8o,
%batch_norm11_gamma_mul_div_expand, %batch_norm11_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_4',
flatten_data='0')
%54 = tvm_op(%47, %resnetv22_stage3_conv2_weight_OI8i8oHW,
num_outputs='1', num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc_1',
flatten_data='0')
%55 = tvm_op(%51, %resnetv22_stage3_conv1_weight_OIHW8i8o,
%54, num_outputs='1', num_inputs='3',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_2',
flatten_data='0')
%58 = tvm_op(%55, %batch_norm12_gamma_mul_div_expand,
%batch_norm12_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu_2', flatten_data='0')
%62 = tvm_op(%58, %resnetv22_stage3_conv3_weight_OIHW8i8o,
%batch_norm13_gamma_mul_div_expand, %batch_norm13_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_5',
flatten_data='0')
%66 = tvm_op(%62, %resnetv22_stage3_conv4_weight_OIHW8i8o, %55,
%batch_norm14_gamma_mul_div_expand, %batch_norm14_add_beta_expand,
num_outputs='1', num_inputs='5',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_2',
flatten_data='0')
%70 = tvm_op(%66, %resnetv22_stage4_conv0_weight_OIHW8i8o,
%batch_norm15_gamma_mul_div_expand, %batch_norm15_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_6',
flatten_data='0')
%73 = tvm_op(%66, %resnetv22_stage4_conv2_weight_OI8i8oHW, num_outputs='1',
num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc_2', flatten_data='0')
%74 = tvm_op(%70, %resnetv22_stage4_conv1_weight_OIHW8i8o, %73, num_outputs='1',
num_inputs='3', func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_3',
flatten_data='0')
%77 = tvm_op(%74, %batch_norm16_gamma_mul_div_expand,
%batch_norm16_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu_3', flatten_data='0')
%81 = tvm_op(%77, %resnetv22_stage4_conv3_weight_OIHW8i8o,
%batch_norm17_gamma_mul_div_expand, %batch_norm17_add_beta_expand,
num_outputs='1', num_inputs='4',
func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_7',
flatten_data='0')
%85 = tvm_op(%81, %resnetv22_stage4_conv4_weight_OIHW8i8o,
%74, %batch_norm18_gamma_mul_div_expand, %batch_norm18_add_beta_expand,
num_outputs='1', num_inputs='5',
func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_3',
flatten_data='0')
%86 = tvm_op(%85, num_outputs='1', num_inputs='1',
func_name='fuse_global_avg_pool2d', flatten_data='0')
%87 = tvm_op(%86, num_outputs='1', num_inputs='1',
func_name='fuse___layout_transform___reshape_flatten___mul_scalar__',
flatten_data='0')
%90 = tvm_op(%87, %resnetv22_dense0_weight,
%__mul_scalar__1, num_outputs='1', num_inputs='3',
func_name='fuse_dense', flatten_data='0')
ret %90
}
graph_attr_keys = [storage_id, shape, dtype, dltype]
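The dump above is plain text, so one way to get a feel for its structure is to read a few nodes back with a short script. The following Python sketch is only an illustration written against the printed format shown here. It is not an NNVM utility, the helper `parse_nodes` is hypothetical, and the regular expressions assume that no parentheses appear inside a node's argument list.

```python
import re

# Two nodes copied from the dump above.
DUMP = """
%8 = tvm_op(%7, num_outputs='1', num_inputs='1',
func_name='fuse_max_pool2d', flatten_data='0')
%11 = tvm_op(%8, %batch_norm2_gamma_mul_div_expand,
%batch_norm2_add_beta_expand, num_outputs='1', num_inputs='3',
func_name='fuse_broadcast_mul_broadcast_add_relu', flatten_data='0')
"""

# Matches "%<id> = tvm_op(<body>)" where the body may span several lines.
NODE_RE = re.compile(r"%(\w+) = tvm_op\((.*?)\)", re.DOTALL)


def parse_nodes(text):
    """Extract the id, input references, and attributes of each tvm_op node."""
    nodes = []
    for node_id, body in NODE_RE.findall(text):
        inputs = re.findall(r"%(\w+)", body)              # %-prefixed operands
        attrs = dict(re.findall(r"(\w+)='([^']*)'", body))  # key='value' pairs
        nodes.append({"id": node_id, "inputs": inputs, "attrs": attrs})
    return nodes


nodes = parse_nodes(DUMP)
for node in nodes:
    print(node["id"], node["attrs"]["func_name"], node["inputs"])
```

Each `tvm_op` node lists its operands first (earlier node results such as `%8`, or graph inputs such as the batch norm parameters), followed by string attributes. In the excerpt, `num_inputs` matches the number of operands, and `func_name` names the fused function that the node calls.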