NNVM Intermediate Representation

Reading time: 30 minutes

NNVM is a reusable graph Intermediate Representation (IR) stack for deep learning systems. It provides APIs to construct, represent, and transform computation graphs, so that most of the high-level optimizations needed in deep learning can be expressed as graph transformations. NNVM is part of the TVM stack and provides a shared compiler that deep learning frameworks can use to optimize, compile, and deploy models to different hardware backends through TVM.
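To make "construct, represent, and transform" concrete, here is a minimal sketch of a graph IR in Python. The names `Node`, `variable`, `apply_op` and `topo_order` are hypothetical and only loosely inspired by NNVM's C++ `nnvm::Node` and `nnvm::Graph`. Each node holds an operator name, its input nodes and string attributes, and a variable is a node without an operator. For simplicity the sketch assumes one output per node, which NNVM itself does not require.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """One graph node: an operator application, or a variable when op is None."""
    name: str
    op: Optional[str] = None
    inputs: List["Node"] = field(default_factory=list)
    attrs: Dict[str, str] = field(default_factory=dict)


def variable(name: str) -> Node:
    """Create an input/parameter node (no operator)."""
    return Node(name=name)


def apply_op(op: str, name: str, *inputs: Node, **attrs: str) -> Node:
    """Create a node that applies `op` to `inputs` with string attributes."""
    return Node(name=name, op=op, inputs=list(inputs), attrs=dict(attrs))


def topo_order(outputs: List[Node]) -> List[Node]:
    """Post-order DFS from the outputs: every node comes after all of its inputs."""
    visited = set()
    order: List[Node] = []

    def visit(node: Node) -> None:
        if node in visited:
            return
        visited.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    for out in outputs:
        visit(out)
    return order


# Construct a tiny conv -> relu graph and list it in topological order.
data = variable("data")
weight = variable("conv0_weight")
conv = apply_op("conv2d", "conv0", data, weight, channels="64")
relu = apply_op("relu", "relu0", conv)
order = topo_order([relu])
print([n.name for n in order])  # ['data', 'conv0_weight', 'conv0', 'relu0']
```

A topological order like this is what most passes walk over: shape inference, for example, can visit nodes in this order and be sure every input has already been processed.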

Key requirements

Given the goals of the TVM stack, the key requirements of NNVM are:

  • Have minimal dependencies in the deployment module
  • Be able to add new operators to the IR in a decentralized fashion
  • Be able to add new optimization passes to the IR and apply them to existing graphs
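The third requirement, adding passes that apply to existing graphs, is typically met with a registry of named graph-to-graph functions. NNVM's C++ side has such a registry (`NNVM_REGISTER_PASS`, run through `ApplyPasses`); the Python sketch below only mimics the idea with hypothetical `register_pass` and `apply_passes` helpers and a deliberately tiny dict-based graph. The `RemoveDropout` and `CountOps` passes are illustrative examples, not NNVM passes.

```python
from typing import Any, Callable, Dict, List

# Hypothetical tiny graph: nodes listed in topological order as
# (name, op, input_names); op "null" marks a variable.
Graph = Dict[str, Any]
PassFn = Callable[[Graph], Graph]

PASS_REGISTRY: Dict[str, PassFn] = {}


def register_pass(name: str) -> Callable[[PassFn], PassFn]:
    """Decorator that adds a graph -> graph function to the registry."""
    def wrap(fn: PassFn) -> PassFn:
        PASS_REGISTRY[name] = fn
        return fn
    return wrap


def apply_passes(graph: Graph, names: List[str]) -> Graph:
    """Run the named passes in order; each returns a new graph."""
    for name in names:
        graph = PASS_REGISTRY[name](graph)
    return graph


@register_pass("RemoveDropout")
def remove_dropout(graph: Graph) -> Graph:
    """Drop 'dropout' nodes (identity at inference) and rewire their consumers."""
    alias: Dict[str, str] = {}
    nodes = []
    for name, op, inputs in graph["nodes"]:
        inputs = [alias.get(i, i) for i in inputs]
        if op == "dropout":
            alias[name] = inputs[0]  # consumers will read dropout's input instead
            continue
        nodes.append((name, op, inputs))
    outputs = [alias.get(o, o) for o in graph["outputs"]]
    return {"nodes": nodes, "outputs": outputs, "attrs": dict(graph["attrs"])}


@register_pass("CountOps")
def count_ops(graph: Graph) -> Graph:
    """Record the number of operator (non-variable) nodes as a graph attribute."""
    n = sum(1 for _, op, _ in graph["nodes"] if op != "null")
    return {**graph, "attrs": {**graph["attrs"], "num_ops": n}}


g = {
    "nodes": [
        ("data", "null", []),
        ("fc_weight", "null", []),
        ("fc", "dense", ["data", "fc_weight"]),
        ("drop", "dropout", ["fc"]),
        ("out", "relu", ["drop"]),
    ],
    "outputs": ["out"],
    "attrs": {},
}
g2 = apply_passes(g, ["RemoveDropout", "CountOps"])
print([n for n, _, _ in g2["nodes"]], g2["attrs"])
# ['data', 'fc_weight', 'fc', 'out'] {'num_ops': 2}
```

Because passes are looked up by name and never mutate their input, a new pass can be registered later and applied to graphs that were built before it existed.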

Key elements

The key elements of NNVM are:

  • An operator registry system to register and add new operators
  • An operator attribute system that provides operator properties in a decentralized fashion
  • A reusable IR data structure for optimization passes
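The registry and attribute system can be sketched as follows. In NNVM's C++ API, operators are declared with `NNVM_REGISTER_OP` and given attributes such as `FInferShape` through `set_attr`, and a pass can fetch one attribute as a map over all operators (`Op::GetAttr`). The Python code below only imitates that shape with hypothetical `register_op`, `set_attr` and `get_attr_map` helpers. The key design choice is that attribute values are stored in a per-attribute table keyed by operator name, so a separate module can attach a new attribute (here a simplified fusion-pattern string, similar in spirit to NNVM's `TOpPattern`) to an operator it did not define.

```python
from typing import Any, Dict, List, Tuple

Shape = Tuple[int, ...]

# Attribute tables live outside the Op objects: attr name -> {op name -> value}.
_ATTR_MAPS: Dict[str, Dict[str, Any]] = {}
_OP_REGISTRY: Dict[str, "Op"] = {}


class Op:
    """A registered operator; setters return self so calls can be chained."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.num_inputs = 1

    def set_num_inputs(self, n: int) -> "Op":
        self.num_inputs = n
        return self

    def set_attr(self, attr_name: str, value: Any) -> "Op":
        _ATTR_MAPS.setdefault(attr_name, {})[self.name] = value
        return self


def register_op(name: str) -> Op:
    """Return the operator called `name`, creating it on first use."""
    if name not in _OP_REGISTRY:
        _OP_REGISTRY[name] = Op(name)
    return _OP_REGISTRY[name]


def get_attr_map(attr_name: str) -> Dict[str, Any]:
    """All values registered for one attribute, keyed by operator name."""
    return _ATTR_MAPS.get(attr_name, {})


def same_shape(shapes: List[Shape]) -> Shape:
    """Shape rule for element-wise ops: all inputs and the output share a shape."""
    if any(s != shapes[0] for s in shapes):
        raise ValueError(f"shape mismatch: {shapes}")
    return shapes[0]


# "Core" module: declares the operators and their shape-inference rules.
register_op("elemwise_add").set_num_inputs(2).set_attr("FInferShape", same_shape)
register_op("relu").set_num_inputs(1).set_attr("FInferShape", same_shape)

# A separate "backend" module later attaches its own attribute to the same ops
# without touching the code above.
register_op("relu").set_attr("TOpPattern", "elemwise")
register_op("elemwise_add").set_attr("TOpPattern", "elemwise")

# A pass fetches one attribute for all operators at once.
infer_shape = get_attr_map("FInferShape")
print(infer_shape["elemwise_add"]([(1, 64, 56, 56), (1, 64, 56, 56)]))  # (1, 64, 56, 56)
print(get_attr_map("TOpPattern"))  # {'relu': 'elemwise', 'elemwise_add': 'elemwise'}
```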

This design allows the NNVM compiler to be used directly as the optimization and compilation stack for frameworks. Because NNVM is extensible, new operators and passes can be added easily without constraining backend providers.

Example of NNVM Intermediate Representation

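This is what a ResNet18 model looks like in NNVM IR (graph representation) after compilation, where each tvm_op node calls a fused TVM function named by its func_name attribute: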


 Graph(%data,
      %batch_norm0_gamma_mul_div_expand,
      %batch_norm0_add_beta_expand,
      %resnetv22_conv0_weight_OIHW3i8o,
      %batch_norm1_gamma_mul_div_expand,
      %batch_norm1_add_beta_expand,
      %batch_norm2_gamma_mul_div_expand,
      %batch_norm2_add_beta_expand,
      %resnetv22_stage1_conv0_weight_OIHW8i8o,
      %batch_norm3_gamma_mul_div_expand,
      %batch_norm3_add_beta_expand,
      %resnetv22_stage1_conv1_weight_OIHW8i8o,
      %batch_norm4_gamma_mul_div_expand,
      %batch_norm4_add_beta_expand,
      %resnetv22_stage1_conv2_weight_OIHW8i8o,
      %batch_norm5_gamma_mul_div_expand,
      %batch_norm5_add_beta_expand,
      %resnetv22_stage1_conv3_weight_OIHW8i8o,
      %batch_norm6_gamma_mul_div_expand,
      %batch_norm6_add_beta_expand,
      %resnetv22_stage2_conv0_weight_OIHW8i8o,
      %batch_norm7_gamma_mul_div_expand,
      %batch_norm7_add_beta_expand,
      %resnetv22_stage2_conv1_weight_OIHW8i8o,
      %resnetv22_stage2_conv2_weight_OI8i8oHW,
      %batch_norm8_gamma_mul_div_expand,
      %batch_norm8_add_beta_expand,
      %resnetv22_stage2_conv3_weight_OIHW8i8o,
      %batch_norm9_gamma_mul_div_expand,
      %batch_norm9_add_beta_expand,
      %resnetv22_stage2_conv4_weight_OIHW8i8o,
      %batch_norm10_gamma_mul_div_expand,
      %batch_norm10_add_beta_expand,
      %resnetv22_stage3_conv0_weight_OIHW8i8o,
      %batch_norm11_gamma_mul_div_expand,
      %batch_norm11_add_beta_expand,
      %resnetv22_stage3_conv1_weight_OIHW8i8o,
      %resnetv22_stage3_conv2_weight_OI8i8oHW,
      %batch_norm12_gamma_mul_div_expand,
      %batch_norm12_add_beta_expand,
      %resnetv22_stage3_conv3_weight_OIHW8i8o,
      %batch_norm13_gamma_mul_div_expand,
      %batch_norm13_add_beta_expand,
      %resnetv22_stage3_conv4_weight_OIHW8i8o,
      %batch_norm14_gamma_mul_div_expand,
      %batch_norm14_add_beta_expand,
      %resnetv22_stage4_conv0_weight_OIHW8i8o,
      %batch_norm15_gamma_mul_div_expand,
      %batch_norm15_add_beta_expand,
      %resnetv22_stage4_conv1_weight_OIHW8i8o,
      %resnetv22_stage4_conv2_weight_OI8i8oHW,
      %batch_norm16_gamma_mul_div_expand,
      %batch_norm16_add_beta_expand,
      %resnetv22_stage4_conv3_weight_OIHW8i8o,
      %batch_norm17_gamma_mul_div_expand,
      %batch_norm17_add_beta_expand,
      %resnetv22_stage4_conv4_weight_OIHW8i8o,
      %batch_norm18_gamma_mul_div_expand,
      %batch_norm18_add_beta_expand,
      %resnetv22_dense0_weight,
      %__mul_scalar__1) {
  %3 = tvm_op(%data, %batch_norm0_gamma_mul_div_expand, 
  %batch_norm0_add_beta_expand, num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add___layout_transform__', 
  flatten_data='0')
  %7 = tvm_op(%3, %resnetv22_conv0_weight_OIHW3i8o, 
  %batch_norm1_gamma_mul_div_expand, %batch_norm1_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu', 
  flatten_data='0')
  %8 = tvm_op(%7, num_outputs='1', num_inputs='1', 
  func_name='fuse_max_pool2d', flatten_data='0')
  %11 = tvm_op(%8, %batch_norm2_gamma_mul_div_expand, 
  %batch_norm2_add_beta_expand, num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add_relu', flatten_data='0')
  %15 = tvm_op(%11, %resnetv22_stage1_conv0_weight_OIHW8i8o,
  %batch_norm3_gamma_mul_div_expand, %batch_norm3_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_1',
  flatten_data='0')
  %17 = tvm_op(%15, %resnetv22_stage1_conv1_weight_OIHW8i8o, %8, 
  num_outputs='1', num_inputs='3', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add', flatten_data='0')
  %20 = tvm_op(%17, %batch_norm4_gamma_mul_div_expand, %batch_norm4_add_beta_expand, 
  num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add_relu', flatten_data='0')
  %24 = tvm_op(%20, %resnetv22_stage1_conv2_weight_OIHW8i8o, 
  %batch_norm5_gamma_mul_div_expand, %batch_norm5_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_1', 
  flatten_data='0')
  %28 = tvm_op(%24, %resnetv22_stage1_conv3_weight_OIHW8i8o, %17, 
  %batch_norm6_gamma_mul_div_expand, %batch_norm6_add_beta_expand, 
  num_outputs='1', num_inputs='5', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu', flatten_data='0')
  %32 = tvm_op(%28, %resnetv22_stage2_conv0_weight_OIHW8i8o, 
  %batch_norm7_gamma_mul_div_expand, %batch_norm7_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_2', 
  flatten_data='0')
  %35 = tvm_op(%28, %resnetv22_stage2_conv2_weight_OI8i8oHW, num_outputs='1', 
  num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc', flatten_data='0')
  %36 = tvm_op(%32, %resnetv22_stage2_conv1_weight_OIHW8i8o, %35, 
  num_outputs='1', num_inputs='3', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_1', flatten_data='0')
  %39 = tvm_op(%36, %batch_norm8_gamma_mul_div_expand, 
  %batch_norm8_add_beta_expand, num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add_relu_1', flatten_data='0')
  %43 = tvm_op(%39, %resnetv22_stage2_conv3_weight_OIHW8i8o, 
  %batch_norm9_gamma_mul_div_expand, %batch_norm9_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_3',
  flatten_data='0')
  %47 = tvm_op(%43, %resnetv22_stage2_conv4_weight_OIHW8i8o, %36, 
  %batch_norm10_gamma_mul_div_expand, %batch_norm10_add_beta_expand, 
  num_outputs='1', num_inputs='5', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_1', flatten_data='0')
  %51 = tvm_op(%47, %resnetv22_stage3_conv0_weight_OIHW8i8o, 
  %batch_norm11_gamma_mul_div_expand, %batch_norm11_add_beta_expand,
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_4', flatten_data='0')
  %54 = tvm_op(%47, %resnetv22_stage3_conv2_weight_OI8i8oHW, 
  num_outputs='1', num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc_1',
  flatten_data='0')
  %55 = tvm_op(%51, %resnetv22_stage3_conv1_weight_OIHW8i8o, 
  %54, num_outputs='1', num_inputs='3', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_2', 
  flatten_data='0')
  %58 = tvm_op(%55, %batch_norm12_gamma_mul_div_expand, 
  %batch_norm12_add_beta_expand, num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add_relu_2', flatten_data='0')
  %62 = tvm_op(%58, %resnetv22_stage3_conv3_weight_OIHW8i8o, 
  %batch_norm13_gamma_mul_div_expand, %batch_norm13_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_5', 
  flatten_data='0')
  %66 = tvm_op(%62, %resnetv22_stage3_conv4_weight_OIHW8i8o, %55, 
  %batch_norm14_gamma_mul_div_expand, %batch_norm14_add_beta_expand,
  num_outputs='1', num_inputs='5', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_2', flatten_data='0')
  %70 = tvm_op(%66, %resnetv22_stage4_conv0_weight_OIHW8i8o, 
  %batch_norm15_gamma_mul_div_expand, %batch_norm15_add_beta_expand, 
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_6', 
  flatten_data='0')
  %73 = tvm_op(%66, %resnetv22_stage4_conv2_weight_OI8i8oHW, num_outputs='1',
  num_inputs='2', func_name='fuse__contrib_conv2d_NCHWc_2', flatten_data='0')
  %74 = tvm_op(%70, %resnetv22_stage4_conv1_weight_OIHW8i8o, %73, num_outputs='1',
  num_inputs='3', func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_3', 
  flatten_data='0')
  %77 = tvm_op(%74, %batch_norm16_gamma_mul_div_expand, 
  %batch_norm16_add_beta_expand, num_outputs='1', num_inputs='3', 
  func_name='fuse_broadcast_mul_broadcast_add_relu_3', flatten_data='0')
  %81 = tvm_op(%77, %resnetv22_stage4_conv3_weight_OIHW8i8o, 
  %batch_norm17_gamma_mul_div_expand, %batch_norm17_add_beta_expand,
  num_outputs='1', num_inputs='4', 
  func_name='fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu_7',
  flatten_data='0')
  %85 = tvm_op(%81, %resnetv22_stage4_conv4_weight_OIHW8i8o, 
  %74, %batch_norm18_gamma_mul_div_expand, %batch_norm18_add_beta_expand, 
  num_outputs='1', num_inputs='5', 
  func_name='fuse__contrib_conv2d_NCHWc_elemwise_add_broadcast_mul_broadcast_add_relu_3', flatten_data='0')
  %86 = tvm_op(%85, num_outputs='1', num_inputs='1', 
  func_name='fuse_global_avg_pool2d', flatten_data='0')
  %87 = tvm_op(%86, num_outputs='1', num_inputs='1',
  func_name='fuse___layout_transform___reshape_flatten___mul_scalar__',
  flatten_data='0')
  %90 = tvm_op(%87, %resnetv22_dense0_weight, 
  %__mul_scalar__1, num_outputs='1', num_inputs='3', 
  func_name='fuse_dense', flatten_data='0')
  ret %90
}
graph_attr_keys = [storage_id, shape, dtype, dltype]
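Reading the dump, each `tvm_op` node calls one fused function whose `func_name` is `fuse_` followed by the member operator names joined with `_`; for example, `fuse__contrib_conv2d_NCHWc_broadcast_mul_broadcast_add_relu` is a convolution followed by the multiply, add and ReLU that stand in for batch normalization. Identical groups appear to share one function (%11 and %20 both call `fuse_broadcast_mul_broadcast_add_relu`), while a numeric suffix such as `_1` marks the same operator sequence elsewhere in the network, presumably with different shapes. The sketch below reproduces that naming on a linear chain resembling %7 to %15. The pattern table, the greedy grouping rule and the signature strings are all assumptions for illustration; NNVM's actual fusion pass works on the whole graph and must also handle nodes with several consumers, such as %8, which feeds both %11 and the residual add in %17.

```python
from typing import Dict, List, Tuple

# Assumed pattern table (simplified): elemwise ops may join the current group,
# out_fusable ops (convolution) may absorb following elemwise ops, and opaque
# ops always stay alone.
PATTERN: Dict[str, str] = {
    "_contrib_conv2d_NCHWc": "out_fusable",
    "max_pool2d": "opaque",
    "broadcast_mul": "elemwise",
    "broadcast_add": "elemwise",
    "relu": "elemwise",
}


def fuse_chain(ops: List[str]) -> List[List[str]]:
    """Greedily group a linear chain of operators into fusion groups."""
    groups: List[List[str]] = []
    for op in ops:
        if groups and PATTERN[op] == "elemwise" and PATTERN[groups[-1][0]] != "opaque":
            groups[-1].append(op)  # absorb into the current group
        else:
            groups.append([op])  # start a new group
    return groups


def name_groups(groups: List[List[str]], signatures: List[str]) -> List[str]:
    """Name each group 'fuse_' + '_'.join(ops). An identical (ops, signature)
    pair reuses its name; a new signature for the same ops gets _1, _2, ..."""
    seen: Dict[Tuple[str, str], str] = {}
    uses: Dict[str, int] = {}
    names: List[str] = []
    for ops, sig in zip(groups, signatures):
        base = "fuse_" + "_".join(ops)
        if (base, sig) not in seen:
            n = uses.get(base, 0)
            seen[(base, sig)] = base if n == 0 else f"{base}_{n}"
            uses[base] = n + 1
        names.append(seen[(base, sig)])
    return names


# A chain resembling %7 to %15 above (the residual edge from max_pool2d is ignored).
ops = ["_contrib_conv2d_NCHWc", "broadcast_mul", "broadcast_add", "relu",
       "max_pool2d",
       "broadcast_mul", "broadcast_add", "relu",
       "_contrib_conv2d_NCHWc", "broadcast_mul", "broadcast_add", "relu"]
groups = fuse_chain(ops)
# Hypothetical input-shape keys, one per group.
signatures = ["1x3x224x224", "1x64x112x112", "1x64x56x56", "1x64x56x56"]
names = name_groups(groups, signatures)
for name in names:
    print(name)
```

The printed names match those of %7, %8, %11 and %15 in the dump, which is consistent with (but does not prove) the naming scheme described above.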