from typing import List
from rich.tree import Tree as RichTree
from rich import print as richprint
import torch
import torch.nn as nn
import re
from collections import OrderedDict
import opendelta.utils.logging as logging
# Module-level logger from opendelta's logging utilities; Visualization.build_common_tree
# uses it to warn about parameterized modules that the structure mapping does not cover.
logger = logging.get_logger(__name__)
class ModuleTree(RichTree):
    r"""A ``rich`` tree node that renders a module name plus optional type/parameter info.

    The rendered label is rich markup of the form ``[main_color]name [type_color](info)`` for
    module nodes, or ``[main_color]name [param_color]info`` when ``is_param_node`` is True.
    Either part is omitted when the corresponding value is None.
    """

    def __init__(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style="tree",
        guide_style="tree.line",
        expanded=True,
        highlight=False,
    ):
        # The attributes must exist before set_label() reads them.
        self.module_name = module_name
        self.info = info
        self.is_param_node = is_param_node
        self.type_color = type_color
        self.param_color = param_color
        self.main_color = main_color
        super().__init__(
            self.set_label(),
            style=style,
            guide_style=guide_style,
            expanded=expanded,
            highlight=highlight,
        )

    def add(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style=None,
        guide_style=None,
        expanded=True,
        highlight=False,
    ):
        r"""Create a child ``ModuleTree``, append it to ``self.children`` and return it.

        ``style``/``guide_style`` (and ``highlight`` when explicitly None) are inherited from
        this node, mirroring ``rich.tree.Tree.add``.
        """
        child = ModuleTree(
            module_name,
            info,
            is_param_node,
            type_color,
            param_color,
            main_color,
            style=style if style is not None else self.style,
            guide_style=guide_style if guide_style is not None else self.guide_style,
            expanded=expanded,
            highlight=highlight if highlight is not None else self.highlight,
        )
        self.children.append(child)
        return child

    def set_label(self):
        r"""Rebuild the markup label from the current attributes, store it on ``self.label`` and return it.

        Note: when ``module_name`` is None the label starts with a space; callers that append a
        parameter label onto a module label rely on that separator.
        """
        text = "" if self.module_name is None else f"[{self.main_color}]{self.module_name}"
        if self.info is not None:
            if self.is_param_node:
                text += f" [{self.param_color}]{self.info}"
            else:
                text += f" [{self.type_color}]({self.info})"
        self.label = text
        return text
class Visualization(object):
    r"""
    Better visualization tool for *BIG* pretrained models.

    - Better repeated block representation
    - Clearer parameter position
    - and Visible parameter state.

    Args:
        plm (:obj:`torch.nn.Module`): The pretrained model, actually can be any pytorch module.
    """

    def __init__(self, plm: nn.Module):
        self.plm = plm
        # rich color names / rgb() specs used in the node labels.
        self.type_color = "green"
        self.param_color = "cyan"
        self.duplicate_color = "red"
        self.normal_color = "white"
        self.virtual_color = "orange"
        self.not_common_color = "bright_black"
        self.no_grad_color = "rgb(0,70,100)"
        self.delta_color = "rgb(175,0,255)"

    def check_mode(self):
        r"""[NODOC] Validate the option combination stored by :meth:`structure_graph`.

        Raises:
            RuntimeError: if ``keep_non_params`` and ``common_structure`` are both set, or if
                ``common_structure`` is set but no ``mapping`` was given.
        """
        if self.keep_non_params and self.common_structure:
            raise RuntimeError("keep_non_params can't be used with common_structure. The common structure only contains parameter nodes.")
        if self.common_structure and self.mapping is None:
            raise RuntimeError("Mapping hasn't been given.")

    def structure_graph(self,
                        rootname="root",
                        expand_params=False,
                        keep_non_params=False,
                        common_structure=False,
                        mapping=None,
                        only_common=False,
                        printTree=True,
                        ):
        r"""Draw the structure graph in command line.

        Args:
            rootname (:obj:`str`): The root node's name.
            expand_params (:obj:`bool`): Display parameter information (shape, etc.) in separate lines.
            keep_non_params (:obj:`bool`): Display the modules that do not have parameters, such as nn.Dropout.
            common_structure (:obj:`bool`): Whether to convert the structure into a common structure defined in
                structure_mapping.py. Modules outside the common structure are displayed in grey.
            mapping (:obj:`dict`): The structure mapping. Must be provided if ``common_structure=True``.
            only_common (:obj:`bool`): Whether to ignore the modules that are not in the common structure.
                This results in a more compact view. Default to False.
            printTree (:obj:`bool`): Whether to print the resulting tree with ``rich``. Default to True.

        Returns:
            :obj:`ModuleTree`: The root of the built tree (repeated blocks merged; parameter lines folded
            into their module's line unless ``expand_params``).

        Raises:
            RuntimeError: on an invalid option combination (see :meth:`check_mode`).
        """
        self.keep_non_params = keep_non_params
        self.expand_params = expand_params
        self.rootname = rootname
        self.only_common = only_common
        self.common_structure = common_structure
        self.mapping = mapping
        self.check_mode()

        self.root_tree = ModuleTree(self.rootname)
        if common_structure:
            self.build_common_tree(self.plm, mapping, self.root_tree)
        else:
            self.build_tree(self.plm, self.root_tree)
        # Merge repeated sibling sub-trees first, then fold parameter lines into their parents.
        self.prune_tree(self.root_tree)
        if not self.expand_params:
            self.fold_param_node(self.root_tree)
        if printTree:
            richprint(self.root_tree)
        return self.root_tree

    @staticmethod
    def _type_name(module):
        r"""[NODOC] Short class name of ``module`` (e.g. ``"Linear"``), shown as the node's type info."""
        return type(module).__name__

    def is_leaf_module(self, module):
        r"""[NODOC] Whether the module is a leaf module (has no child modules).
        """
        return next(module.children(), None) is None

    def build_tree(self, module: nn.Module, tree: ModuleTree = None):
        r"""[NODOC] Build the original tree structure: one node per child module, recursively,
        each followed by nodes for the parameters that module owns directly.
        """
        # A leaf module has no named_children, so the loop body simply does not run.
        for name, child in module.named_children():
            node = tree.add(name, info=self._type_name(child), type_color=self.type_color)
            self.add_param_info_node(child, node)
            self.build_tree(module=child, tree=node)

    def has_parameter(self, module):
        r"""[NODOC] Whether ``module`` (or any of its sub-modules) has at least one parameter."""
        return next(module.parameters(), None) is not None

    def build_common_tree(self, module: nn.Module, mapping, tree: ModuleTree = None, query="", key_to_root=""):
        r""" (Unstable) build the common tree structure

        Args:
            module: The module whose children are placed under ``tree``.
            mapping (:obj:`dict`): The (sub-)mapping for this level. Keys are (possibly dotted) child names;
                values are dicts with a ``"__name__"`` entry giving the common name (``""`` for "not common",
                a dotted name to create a hierarchy). The key ``"$"`` matches any child.
            tree (:obj:`ModuleTree`): The node the children are attached to.
            query (:obj:`str`): Dotted prefix accumulated while descending through modules that did not
                match the mapping yet.
            key_to_root (:obj:`str`): Dotted path of matched keys from the root (bookkeeping only).
        """
        if self.is_leaf_module(module):
            # A non-empty query means we descended without ever matching the mapping.
            if len(query) > 0 and self.has_parameter(module):
                logger.warning(f"Parameter node {query} not found under tree {tree.module_name} and module {module}. Is your mapping correct?")  # WARNING
            return

        for n, m in module.named_children():
            new_query = query + n
            type_info = self._type_name(m)
            if new_query in mapping:
                new_mapping = mapping[new_query]
                name = new_mapping["__name__"]
                if len(name.split(".")) > 1:
                    # The common name contains a hierarchy: create/find virtual parents, then add the node.
                    hierachical_name = name.split(".")
                    temp_tree = self.find_or_insert(tree, hierachical_name)
                    newnode = temp_tree.add(hierachical_name[-1], info=type_info, type_color=self.type_color)
                elif name == "":
                    # The key is known but not part of the predefined common structure.
                    if self.only_common:
                        continue
                    newnode = tree.add(new_query, info=type_info, main_color=self.not_common_color, type_color=self.not_common_color)
                else:
                    # A single common key: reuse an existing (e.g. virtual) node of that name if any.
                    newnode = self.find_not_insert(tree, [name, ""])
                    if newnode is not None:
                        newnode.info = type_info
                        newnode.type_color = self.type_color
                        newnode.set_label()
                    else:
                        newnode = tree.add(name, info=type_info, type_color=self.type_color)
            elif "$" in mapping:
                # Wildcard: keep the original child name and use the "$" sub-mapping.
                new_mapping = mapping["$"]
                newnode = tree.add(n, info=type_info, type_color=self.type_color)
            else:
                # Not matched yet: descend with a longer dotted query, attaching to the same tree node.
                self.build_common_tree(module=m, tree=tree, mapping=mapping, query=new_query + ".", key_to_root=key_to_root)
                continue

            self.add_param_info_node(m, newnode)
            self.build_common_tree(module=m, tree=newnode, mapping=new_mapping, key_to_root=key_to_root + "." + new_query)

    @staticmethod
    def _find_child(tree, name):
        r"""[NODOC] Return the first child of ``tree`` whose ``module_name`` equals ``name``, or None."""
        return next((child for child in tree.children if child.module_name == name), None)

    def find_or_insert(self, tree: ModuleTree, hierachical_name: List[str]):
        r"""[NODOC] Walk ``hierachical_name[:-1]`` down from ``tree`` and return the node reached,
        inserting a "Virtual" node for every missing level.
        """
        if len(hierachical_name) == 1:
            return tree
        head = hierachical_name[0]
        child = self._find_child(tree, head)
        if child is None:
            child = tree.add(head, info="Virtual", type_color=self.virtual_color)
        return self.find_or_insert(child, hierachical_name=hierachical_name[1:])

    def find_not_insert(self, tree: ModuleTree, hierachical_name: List[str]):
        r"""[NODOC] Walk ``hierachical_name[:-1]`` down from ``tree`` without inserting anything;
        return the node reached, or None if some level is missing.
        """
        if len(hierachical_name) == 1:
            return tree
        child = self._find_child(tree, hierachical_name[0])
        if child is None:
            return None
        return self.find_not_insert(child, hierachical_name=hierachical_name[1:])

    def fold_param_node(self, t: ModuleTree, p: ModuleTree = None):
        r"""[NODOC] Place the parameters' information node right after the module that contains the parameters.
        E.g. w1 (Linear)
             -- weight: [32128, 1024]
        =>
             w1 (Linear) weight: [32128, 1024]

        Returns:
            :obj:`bool`: whether ``t`` should be removed from its parent ``p``.
        """
        if getattr(t, "is_param_node", False):
            # Parameter labels start with a space, so plain concatenation keeps them separated.
            p.label += t.label
            return True
        if len(t.children) == 0:
            # A module without parameters: drop it unless the user asked to keep such modules.
            return not self.keep_non_params
        # Children are visited in order so folded parameter labels keep their original order.
        t.children = [child for child in t.children if not self.fold_param_node(t=child, p=t)]
        return False

    def prune_tree(self, t: ModuleTree):
        r"""[NODOC] Calculate the _finger_print of a module as the _finger_print of all child nodes plus its own label.
        A leaf node has _finger_print == label.
        Sibling nodes whose _finger_print is identical once their own label (the "head") is removed
        are merged into a single node.
        """
        if len(t.children) == 0:
            t._finger_print = t.label
            return
        for child in t.children:
            self.prune_tree(child)
        t._finger_print = t.label + "::" + ";".join(child._finger_print for child in t.children)

        groups = OrderedDict()
        for child in t.children:
            # Drop the head (text before the first "::"); a leaf fingerprint has no "::" and is kept whole.
            head, sep, body = child._finger_print.partition("::")
            groups.setdefault(body if sep else head, []).append(child)
        t.children = [self.extract_common_and_join(members) for members in groups.values()]

    def extract_common_and_join(self, l: List[ModuleTree]):
        r"""[NODOC] Some modules that have the same info (e.g., are all "Linear") have different names (e.g., w1,w2).
        Merge them into the first node of ``l``, which is returned.
        E.g. tree1.module_name = "w1", tree1.info = "Linear"; tree2.module_name = "w2", tree2.info = "Linear"
        -> representative.module_name = "w1,w2", representative.info = "Linear"
        """
        representative = l[0]
        if len(l) == 1:
            return representative

        names_by_info = OrderedDict()
        for node in l:
            names_by_info.setdefault(node.info, []).append(node.module_name)

        name_parts, info_parts, label_parts = [], [], []
        for info, names in names_by_info.items():
            joined = self.neat_expr(names)
            name_parts.append(joined)
            info_parts.append(info)
            label_parts.append(f"[{self.duplicate_color}]{joined}[{self.type_color}]({info})")

        typeinfos = ",".join(info_parts)
        representative.module_name = ",".join(name_parts)
        representative.info = typeinfos
        # Kept for backward compatibility with code that read the old attribute name.
        representative.type_info = typeinfos
        representative.label = ",".join(label_parts)
        return representative

    def neat_expr(self, l: List[str]):
        r"""[NODOC] A small tool function to arrange consecutive numbers into an interval display.
        E.g., ["1","2","3","5","6","9","10","11","12"] -> "1-3,5-6,9-12"; ["4"] -> "4".
        Names that are not all integers are simply joined with ",".
        """
        try:
            numbers = [int(x.strip()) for x in l]
        except ValueError:
            return ",".join(l)
        return ",".join(str(start) if start == end else f"{start}-{end}" for start, end in self.ranges(numbers))

    def ranges(self, nums: List[int]):
        r"""[NODOC] A small tool function to arrange consecutive numbers into intervals.
        E.g., [1,2,3,5,6,9,10,11,12] -> [(1,3),(5,6),(9,12)]. Duplicates are ignored.
        """
        nums = sorted(set(nums))
        # Each gap contributes the end of one run and the start of the next.
        gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
        edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
        return list(zip(edges, edges))

    def add_param_info_node(self, m: nn.Module, tree: ModuleTree, record_grad_state=True, record_delta=True):
        r"""[NODOC] Add parameter information of the module. Only parameters owned directly by ``m``
        (i.e. created with nn.Parameter on ``m`` itself, not inside a child module) are added here.

        A parameter is colored ``no_grad_color`` when it does not require grad (if ``record_grad_state``),
        ``delta_color`` when it carries a truthy ``_is_delta`` attribute (if ``record_delta``), and
        ``param_color`` otherwise.

        Raises:
            RuntimeError: if a parameter not under any known child module has a dotted name.
        """
        known_children = {name for name, _ in m.named_children()}
        for n, p in m.named_parameters():
            if n.split(".")[0] in known_children:
                continue  # shown under the owning child module instead
            if "." in n:
                raise RuntimeError(f"The name field {n} should be a parameter since it doesn't appear in named_children, but it contains '.'")
            info = "{}:{}".format(n, list(p.shape))
            color = self.param_color
            if record_grad_state and not p.requires_grad:
                color = self.no_grad_color
            if record_delta and getattr(p, "_is_delta", False):
                color = self.delta_color
            tree.add(info=info, is_param_node=True, param_color=color)
if __name__ == "__main__":
    # example command line:
    # 1. python opendelta/utils/visualization.py --model t5-lm --model_name_or_path t5-large-lm-adapt --common_structure --only_common
    # 2. python opendelta/utils/visualization.py --model roberta --model_name_or_path roberta-large --common_structure
    # 3. python opendelta/utils/visualization.py --model gpt2 --model_name_or_path gpt2-medium --keep_non_params --expand_params
    from openprompt.plms import load_plm
    import argparse
    import os

    parser = argparse.ArgumentParser("")
    parser.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.")
    parser.add_argument("--model_name_or_path", default="t5-large-lm-adapt")
    # NOTE(review): developer-specific default; pass --cache_base explicitly on other machines.
    parser.add_argument("--cache_base", default='/home/hushengding/plm_cache/')
    parser.add_argument("--keep_non_params", action="store_true", help="Display the modules that does not have parameters, such as nn.Dropout")
    parser.add_argument("--expand_params", action="store_true", help="Display parameter infomation (shape, etc) in seperate lines. ")
    parser.add_argument("--common_structure", action="store_true", help="Whether convert the structure into a common structure defined in structure_mapping.py. The not common structure will be displayed in grey." )
    parser.add_argument("--only_common", action="store_true", help="Whether ignore the modules that are not in common structure. This will result in a more compact view. Default to False")
    args = parser.parse_args()

    # os.path.join also works when --cache_base is given without a trailing slash.
    model_path = os.path.join(args.cache_base, args.model_name_or_path)
    plm, tokenizer, model_config, WrapperClass = load_plm(args.model, model_path)
    print("Model Loaded!")

    if args.common_structure:
        from opendelta.utils.structure_mapping import Mappings
        mapping = Mappings[args.model]
    else:
        mapping = None

    visobj = Visualization(plm)
    visobj.structure_graph(rootname=args.model_name_or_path, keep_non_params=args.keep_non_params, expand_params=args.expand_params, common_structure=args.common_structure, only_common=args.only_common, mapping=mapping)