31 lines
1.0 KiB
Python
31 lines
1.0 KiB
Python
|
def replace_layer_recursive(model, old_layer, new_layer):
    """Replace the first occurrence of ``old_layer`` in ``model``'s child tree.

    Children are visited depth-first in ``_modules`` order: each child is
    checked, then its own children are searched before moving on to the next
    sibling. The first child that *is* ``old_layer`` (identity, not equality)
    is overwritten with ``new_layer`` in its parent's ``_modules`` mapping and
    the search stops.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child
            (presumably a ``torch.nn.Module`` -- TODO confirm). ``model``
            itself is never compared against ``old_layer``.
        old_layer: the exact layer instance to replace.
        new_layer: the object stored in its place.

    Returns:
        True if a replacement was made, False if ``old_layer`` was not found.
    """
    for name, layer in model._modules.items():
        # Identity check: we are looking for this specific instance.
        if layer is old_layer:
            # Overwriting an existing key does not resize the dict, so doing
            # it while iterating over ``items()`` is safe.
            model._modules[name] = new_layer
            return True
        # Empty child slots (None) have no ``_modules`` to search.
        if layer is not None and replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False
|
||
|
|
||
|
|
||
|
def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
    """Replace every descendant of ``model`` that is an ``old_layer_type``.

    Each matching child (``isinstance`` check, so subclasses match too) is
    overwritten with ``new_layer`` in its parent's ``_modules`` mapping.
    Non-matching children are searched recursively. A replaced child's own
    subtree is NOT visited: that subtree is detached from ``model``, and
    mutating it would only alter an object the caller may still hold.
    ``new_layer`` is never descended into either.

    Note: the same ``new_layer`` object is inserted at every match; no copy
    is made.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child
            (presumably a ``torch.nn.Module`` -- TODO confirm). ``model``
            itself is never replaced.
        old_layer_type: class (or tuple of classes) accepted by ``isinstance``.
        new_layer: the object stored in place of each match.
    """
    for name, layer in model._modules.items():
        if isinstance(layer, old_layer_type):
            # Overwriting an existing key does not resize the dict, so doing
            # it while iterating over ``items()`` is safe.
            model._modules[name] = new_layer
        elif layer is not None:
            # Empty child slots (None) have no ``_modules`` to search.
            replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
|
||
|
|
||
|
|
||
|
def find_layer_types_recursive(model, layer_types):
    """Return every descendant of ``model`` whose exact type is in ``layer_types``.

    Matching uses ``type(layer)``, so instances of subclasses of a listed
    class are NOT matched. ``model`` itself is never tested; only the
    descendants reached through ``find_layer_predicate_recursive`` are.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child
            (presumably a ``torch.nn.Module`` -- TODO confirm).
        layer_types: a single class, or any iterable of classes.

    Returns:
        list of matching layers, in the order returned by
        ``find_layer_predicate_recursive``.
    """
    if isinstance(layer_types, type):
        # Accept a bare class as well as a collection of classes.
        layer_types = (layer_types,)
    else:
        # Materialise once: the membership test runs for every node, and a
        # one-shot iterable (e.g. a generator) would be exhausted after the
        # first one.
        layer_types = tuple(layer_types)

    def predicate(layer):
        return type(layer) in layer_types

    return find_layer_predicate_recursive(model, predicate)
|
||
|
|
||
|
|
||
|
def find_layer_predicate_recursive(model, predicate):
    """Collect every descendant of ``model`` for which ``predicate`` is truthy.

    The traversal is pre-order in ``_modules`` order: a child is tested and
    (if it matches) recorded before any of its own descendants. ``model``
    itself is not tested. Empty child slots (None) are skipped entirely. A
    layer reachable through several parents is visited (and may be reported)
    once per path, because no de-duplication is done.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child
            (presumably a ``torch.nn.Module`` -- TODO confirm).
        predicate: callable taking one layer and returning a truthy value to
            include it.

    Returns:
        list of matching layers.
    """
    result = []

    def _collect(parent):
        # Append into one shared list instead of building and extending a
        # fresh list per level, which copied each match once per ancestor.
        for layer in parent._modules.values():
            if layer is None:
                continue
            if predicate(layer):
                result.append(layer)
            _collect(layer)

    _collect(model)
    return result
|