Function parameter registration method.
It registers the parameters involved in some data processing functions, e.g., BatchNorm1d.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | The RPN model. | required |
| func_list | list | The list of data processing functions. | required |
Returns:
| Type | Description |
|---|---|
| None | This function doesn't have any return values. |
Source code in tinybig/util/util.py
def register_function_parameters(model, func_list):
    """
    Function parameter registration method.

    It registers the parameters involved in some data processing functions, e.g., BatchNorm1d.

    Parameters
    ----------
    model: Any
        The RPN model.
    func_list: list
        The list of data processing functions.

    Returns
    -------
    None
        This function doesn't have any return values.
    """
    # Accept a single function as well as a list of functions.
    if not hasattr(func_list, '__iter__'):
        func_list = [func_list]
    for idx, function in enumerate(func_list):
        # Learnable affine parameters (e.g., of BatchNorm1d or LayerNorm).
        if hasattr(function, 'weight') and hasattr(function, 'bias'):
            model.register_parameter(f'layer_{idx}_weight', function.weight)
            model.register_parameter(f'layer_{idx}_bias', function.bias)
        # Non-learnable running statistics (e.g., of BatchNorm1d).
        if hasattr(function, 'running_mean') and hasattr(function, 'running_var'):
            model.register_buffer(f'layer_{idx}_running_mean', function.running_mean)
            model.register_buffer(f'layer_{idx}_running_var', function.running_var)
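
The following is a minimal usage sketch, not taken from the tinybig sources. It assumes PyTorch is installed and uses a hypothetical `ToyModel` class as a stand-in for an RPN model. The function body is copied from the listing above so that the snippet runs without tinybig. It illustrates one situation where registration is likely useful: when data processing functions are kept in a plain Python list, `nn.Module` does not track their parameters, so they would otherwise be missing from `named_parameters()` and `named_buffers()`.

```python
import torch
from torch import nn


def register_function_parameters(model, func_list):
    # Local copy of the function documented above (assumption: same behavior),
    # included only so this sketch is self-contained.
    if not hasattr(func_list, '__iter__'):
        func_list = [func_list]
    for idx, function in enumerate(func_list):
        if hasattr(function, 'weight') and hasattr(function, 'bias'):
            model.register_parameter(f'layer_{idx}_weight', function.weight)
            model.register_parameter(f'layer_{idx}_bias', function.bias)
        if hasattr(function, 'running_mean') and hasattr(function, 'running_var'):
            model.register_buffer(f'layer_{idx}_running_mean', function.running_mean)
            model.register_buffer(f'layer_{idx}_running_var', function.running_var)


class ToyModel(nn.Module):
    """Hypothetical stand-in for an RPN model; not part of tinybig."""

    def __init__(self, dim):
        super().__init__()
        # A plain list is invisible to nn.Module's parameter bookkeeping.
        self.funcs = [nn.BatchNorm1d(dim), torch.sigmoid]
        register_function_parameters(self, self.funcs)

    def forward(self, x):
        for f in self.funcs:
            x = f(x)
        return x


model = ToyModel(dim=4)
param_names = sorted(name for name, _ in model.named_parameters())
buffer_names = sorted(name for name, _ in model.named_buffers())
print(param_names)   # ['layer_0_bias', 'layer_0_weight']
print(buffer_names)  # ['layer_0_running_mean', 'layer_0_running_var']
```

The index in each registered name is the position of the function in `func_list`. Functions without `weight`/`bias` or running statistics (here `torch.sigmoid` at index 1) are skipped, so the names need not be consecutive. The registered entries are the same tensor objects held by the original function, not copies.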