# Source code for bigmodelvis.visualization

from typing import List
from rich.tree import Tree as RichTree
from rich import print as richprint
# import torch
import torch.nn as nn
import re
from collections import OrderedDict
from .vis_logging import get_logger
# Module-level logger from the package's vis_logging helper; used to report
# mapping problems while building the common-structure tree.
logger = get_logger(__name__)
class ModuleTree(RichTree):
    """A rich ``Tree`` node describing one module (or one parameter) of a model.

    The rendered label is derived from ``module_name`` and ``info`` by
    :meth:`set_label`: module nodes show ``info`` in parentheses using
    ``type_color``; parameter nodes (``is_param_node=True``) show ``info``
    bare using ``param_color``. The name itself is drawn in ``main_color``.
    """

    def __init__(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style="tree",
        guide_style="tree.line",
        expanded=True,
        highlight=False,
        ):
        # set_label() reads these attributes, so they must exist before the
        # label is built and handed to the rich Tree constructor.
        self.module_name = module_name
        self.info = info
        self.is_param_node = is_param_node
        self.type_color = type_color
        self.param_color = param_color
        self.main_color = main_color
        super().__init__(
            self.set_label(),
            style=style,
            guide_style=guide_style,
            expanded=expanded,
            highlight=highlight,
        )

    def add(
        self,
        module_name=None,
        info=None,
        is_param_node=False,
        type_color="green",
        param_color="red",
        main_color="white",
        style=None,
        guide_style=None,
        expanded=True,
        highlight=False,
        ):
        """Create a child ``ModuleTree``, append it to ``children`` and return it.

        ``style``/``guide_style``/``highlight`` left as ``None`` are inherited
        from this node.
        """
        child = ModuleTree(
            module_name,
            info,
            is_param_node,
            type_color,
            param_color,
            main_color,
            style=style if style is not None else self.style,
            guide_style=guide_style if guide_style is not None else self.guide_style,
            expanded=expanded,
            highlight=highlight if highlight is not None else self.highlight,
        )
        self.children.append(child)
        return child

    def set_label(self):
        """Rebuild the rich-markup label from the node's fields, store it on
        ``self.label`` and return it."""
        text = "" if self.module_name is None else f"[{self.main_color}]{self.module_name}"
        if self.info is not None:
            if self.is_param_node:
                text += f" [{self.param_color}]{self.info}"
            else:
                text += f" [{self.type_color}]({self.info})"
        self.label = text
        return text


class Visualization(object):
    r"""
    Better visualization tool for *BIG* pretrained models.

    - Better repeated block representation
    - Clearer parameter position
    - and Visible parameter state.

    Args:
        plm (:obj:`torch.nn.Module`): The pretrained model, actually can be any pytorch module.
    """

    def __init__(self, plm: nn.Module):
        self.plm = plm
        self.type_color = "green"
        self.param_color = "cyan"
        self.duplicate_color = "red"
        self.normal_color = "white"
        # rich defines no plain "orange" color name; "orange1" is its standard orange.
        self.virtual_color = "orange1"
        self.not_common_color = "bright_black"
        self.no_grad_color = "rgb(0,70,100)"
        self.delta_color = "rgb(175,0,255)"

    def check_mode(self):
        r"""[NODOC] Validate the option combination stored by :meth:`structure_graph`.

        Raises:
            RuntimeError: if ``keep_non_params`` and ``common_structure`` are both set,
                or if ``common_structure`` is set while ``mapping`` is ``None``.
        """
        if self.keep_non_params and self.common_structure:
            raise RuntimeError("keep_non_params can't be used with common_structure. The common structure only contains parameter nodes.")
        if self.common_structure and self.mapping is None:
            raise RuntimeError("Mapping hasn't been given.")

    def structure_graph(self,
                        rootname="root",
                        expand_params=False,
                        keep_non_params=False,
                        common_structure=False,
                        mapping=None,
                        only_common=False,
                        printTree=True,
                        ):
        r"""Draw the structure graph in command line.

        Args:
            rootname (:obj:`str`): The root node's name.
            expand_params (:obj:`bool`): Display parameter information (shape, etc.) in separate lines.
                When False, each parameter is folded into its module's line.
            keep_non_params (:obj:`bool`): Display the modules that do not have parameters, such as nn.Dropout.
            common_structure (:obj:`bool`): Whether to convert the structure into the common structure
                described by ``mapping``. Modules outside the common structure are displayed in grey.
            mapping (:obj:`dict`): The structure mapping. Must be provided if ``common_structure=True``.
            only_common (:obj:`bool`): Whether to ignore the modules that are not in the common structure.
                This results in a more compact view. Defaults to False.
            printTree (:obj:`bool`): Whether to print the tree with rich before returning it.

        Returns:
            :obj:`ModuleTree`: the root of the built (and pruned) tree, also stored as ``self.root_tree``.

        Raises:
            RuntimeError: on an invalid option combination (see :meth:`check_mode`).
        """
        self.keep_non_params = keep_non_params
        self.expand_params = expand_params
        self.rootname = rootname
        self.only_common = only_common
        self.common_structure = common_structure
        self.mapping = mapping
        self.check_mode()

        self.root_tree = ModuleTree(self.rootname)
        if common_structure:
            self.build_common_tree(self.plm, mapping, self.root_tree)
        else:
            self.build_tree(self.plm, self.root_tree)
        # Merge repeated blocks first, then (optionally) fold parameter lines into their modules.
        self.prune_tree(self.root_tree)
        if not self.expand_params:
            self.fold_param_node(self.root_tree)
        if printTree:
            richprint(self.root_tree)
        return self.root_tree

    @staticmethod
    def _module_type_name(module):
        r"""[NODOC] Short class name of ``module`` (e.g. ``"Linear"``)."""
        return type(module).__name__

    @staticmethod
    def _find_child(tree, module_name):
        r"""[NODOC] First direct child of ``tree`` whose ``module_name`` equals ``module_name``, else None."""
        return next((child for child in tree.children if child.module_name == module_name), None)

    def is_leaf_module(self, module):
        r"""[NODOC] Whether the module is a leaf module (has no child modules)."""
        return next(iter(module.named_children()), None) is None

    def build_tree(self, module: nn.Module, tree: "ModuleTree" = None):
        r"""[NODOC] Build the original tree structure.

        For every child of ``module`` a node ``name (ClassName)`` is added under ``tree``,
        followed by one node per parameter owned directly by that child; then recurse.
        """
        for name, child in module.named_children():
            node = tree.add(name, info=self._module_type_name(child), type_color=self.type_color)
            self.add_param_info_node(child, node)
            self.build_tree(module=child, tree=node)

    def has_parameter(self, module):
        r"""[NODOC] Whether ``module`` (including its descendants) has at least one parameter."""
        return next(iter(module.parameters()), None) is not None

    def build_common_tree(self, module: nn.Module, mapping, tree: "ModuleTree" = None, query="", key_to_root=""):
        r"""(Unstable) build the common tree structure.

        Args:
            module: module whose children are being placed.
            mapping (:obj:`dict`): mapping level that applies to ``module``'s children. A key equal to
                a child's (dotted) name selects an entry whose ``"__name__"`` is the common name to show
                (``""`` marks a module outside the common structure; a dotted name creates virtual
                parent nodes). A ``"$"`` key matches any child name and keeps the original name.
            tree: node under which the mapped children are attached.
            query (:obj:`str`): dotted prefix accumulated through intermediate modules that had no
                entry in ``mapping``.
            key_to_root (:obj:`str`): dotted path of matched keys from the root (bookkeeping only).
        """
        if self.is_leaf_module(module):
            if query and self.has_parameter(module):
                # A parameterized leaf was reached without ever matching the mapping.
                logger.warning(f"Parameter node {query} not found under tree {tree.module_name} and module {module}. Is your mapping correct?")
            return

        for n, m in module.named_children():
            new_query = query + n
            if new_query not in mapping and "$" not in mapping:
                # No entry yet: descend under the same parent node and extend the query,
                # so that e.g. "layer" and "0" can later match the dotted key "layer.0".
                self.build_common_tree(module=m, tree=tree, mapping=mapping, query=new_query + ".", key_to_root=key_to_root)
                continue

            type_info = self._module_type_name(m)
            if new_query in mapping:
                new_mapping = mapping[new_query]
                name = new_mapping["__name__"]
                if "." in name:
                    # Hierarchical common name: locate/create the virtual parents first.
                    hierachical_name = name.split(".")
                    parent = self.find_or_insert(tree, hierachical_name)
                    newnode = parent.add(hierachical_name[-1], info=type_info, type_color=self.type_color)
                elif name == "":
                    # Explicitly not part of the common structure.
                    if self.only_common:
                        continue
                    newnode = tree.add(new_query, info=type_info, main_color=self.not_common_color, type_color=self.not_common_color)
                else:
                    # Single common name: reuse an existing same-named child (e.g. a virtual node).
                    newnode = self.find_not_insert(tree, [name, ""])
                    if newnode is not None:
                        newnode.info = type_info
                        newnode.type_color = self.type_color
                        newnode.set_label()
                    else:
                        newnode = tree.add(name, info=type_info, type_color=self.type_color)
            else:
                # Only the wildcard applies: keep the child's original name.
                new_mapping = mapping["$"]
                newnode = tree.add(n, info=type_info, type_color=self.type_color)

            self.add_param_info_node(m, newnode)
            self.build_common_tree(module=m, tree=newnode, mapping=new_mapping, key_to_root=key_to_root + "." + new_query)

    def find_or_insert(self, tree: "ModuleTree", hierachical_name: List[str]):
        r"""[NODOC] Walk ``hierachical_name[:-1]`` down from ``tree``, inserting a virtual node
        for every missing level, and return the node that should receive the last component."""
        if len(hierachical_name) == 1:
            return tree
        child = self._find_child(tree, hierachical_name[0])
        if child is None:
            child = tree.add(hierachical_name[0], info="Virtual", type_color=self.virtual_color)
        return self.find_or_insert(child, hierachical_name=hierachical_name[1:])

    def find_not_insert(self, tree: "ModuleTree", hierachical_name: List[str]):
        r"""[NODOC] Like :meth:`find_or_insert` but never inserts; returns None if a level is missing."""
        if len(hierachical_name) == 1:
            return tree
        child = self._find_child(tree, hierachical_name[0])
        if child is None:
            return None
        return self.find_not_insert(child, hierachical_name=hierachical_name[1:])

    def fold_param_node(self, t: "ModuleTree", p: "ModuleTree" = None):
        r"""[NODOC] Place the parameters' information node right after the module that contains
        the parameters. E.g.
        w1 (Linear) -- weight: [32128, 1024]
        =>
        w1 (Linear) weight: [32128, 1024]

        Leaves that are not parameter nodes are dropped unless ``keep_non_params`` is set.

        Returns:
            bool: whether ``t`` should be removed from its parent's children.
        """
        if getattr(t, "is_param_node", False):
            p.label += t.label
            return True
        if not t.children:
            return not self.keep_non_params
        t.children = [child for child in t.children if not self.fold_param_node(t=child, p=t)]
        return False

    def prune_tree(self, t: "ModuleTree"):
        r"""[NODOC] Merge sibling sub-trees with identical content.

        The _finger_print of a node is its label followed by "::" and the _finger_prints of its
        children; a leaf's _finger_print is its label. Siblings whose _finger_print matches once
        their own label (the head) is stripped are merged into a single node by
        :meth:`extract_common_and_join`. Leaves keep their whole label as the key, so they only
        merge when the labels are identical.
        """
        if not t.children:
            t._finger_print = t.label
            return
        for sub_tree in t.children:
            self.prune_tree(sub_tree)
        t._finger_print = t.label + "::" + ";".join(child._finger_print for child in t.children)

        groups = OrderedDict()
        for child_id, sub_tree in enumerate(t.children):
            head, sep, tail = sub_tree._finger_print.partition("::")
            groups.setdefault(tail if sep else head, []).append(child_id)

        t.children = [
            self.extract_common_and_join([t.children[i] for i in ids])
            for ids in groups.values()
        ]

    def extract_common_and_join(self, l: List["ModuleTree"]):
        r"""[NODOC] Merge modules that have the same structure but different names.

        Names are grouped by ``info`` (e.g. all "Linear"), compacted with :meth:`neat_expr` and
        written onto the first node, which is returned. E.g.
        tree1.module_name = "w1", tree1.info = "Linear";
        tree2.module_name = "w2", tree2.info = "Linear"
        -> representative.module_name = "w1,w2", representative.info = "Linear"
        """
        representative = l[0]
        if len(l) == 1:
            return representative

        names_by_type = OrderedDict()
        for node in l:
            names_by_type.setdefault(node.info, []).append(node.module_name)

        names, typeinfos, pieces = [], [], []
        for type_name, members in names_by_type.items():
            compact = self.neat_expr(members)
            names.append(compact)
            typeinfos.append(str(type_name))
            pieces.append(f"[{self.duplicate_color}]{compact}[{self.type_color}]({type_name})")

        representative.module_name = ",".join(names)
        representative.type_info = ",".join(typeinfos)
        representative.label = ",".join(pieces)
        return representative

    def neat_expr(self, l: List[str]):
        r"""[NODOC] Compact a list of names, turning runs of consecutive integers into intervals.

        E.g., ["1","2","3","5","6","9","10","11","12"] -> "1-3,5-6,9-12";
        ["0","2","3"] -> "0,2-3". Names that are not all integers are joined unchanged.
        """
        try:
            numbers = [int(x.strip()) for x in l]
        except (AttributeError, TypeError, ValueError):
            return ",".join(str(x) for x in l)
        return ",".join(
            str(start) if start == end else f"{start}-{end}"
            for start, end in self.ranges(numbers)
        )

    def ranges(self, nums: List[int]):
        r"""[NODOC] Group integers into consecutive runs (duplicates are ignored).

        E.g., [1,2,3,5,6,9,10,11,12] -> [(1, 3), (5, 6), (9, 12)]
        """
        nums = sorted(set(nums))
        gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
        edges = iter(nums[:1] + [v for gap in gaps for v in gap] + nums[-1:])
        return list(zip(edges, edges))

    def add_param_info_node(self, m: nn.Module, tree: "ModuleTree", record_grad_state=True, record_delta=True):
        r"""[NODOC] Add one node under ``tree`` per parameter owned directly by ``m``.

        Parameters that belong to a child module are skipped (they are reported under that
        child). Parameters with ``requires_grad=False`` use ``no_grad_color`` and parameters
        flagged with a truthy ``_is_delta`` attribute use ``delta_color``.

        Raises:
            RuntimeError: if a parameter name contains '.' but its prefix is not a child module.
        """
        known_module = {name for name, _ in m.named_children()}
        for n, p in m.named_parameters():
            if n.split(".")[0] in known_module:
                continue
            if "." in n:
                raise RuntimeError(f"The name field {n} should be a parameter since it doesn't appear in named_children, but it contains '.'")
            info = "{}:{}".format(n, list(p.shape))
            color = self.param_color
            if record_grad_state and not p.requires_grad:
                color = self.no_grad_color
            if record_delta and getattr(p, "_is_delta", False):
                color = self.delta_color
            tree.add(info=info, is_param_node=True, param_color=color)
if __name__ == "__main__":
    # Example command lines:
    # 1. python opendelta/utils/visualization.py --model t5-lm --model_name_or_path t5-large-lm-adapt --common_structure --only_common
    # 2. python opendelta/utils/visualization.py --model roberta --model_name_or_path roberta-large --common_structure
    # 3. python opendelta/utils/visualization.py --model gpt2 --model_name_or_path gpt2-medium --keep_non_params --expand_params
    import argparse
    import os

    from openprompt.plms import load_plm

    parser = argparse.ArgumentParser("")
    parser.add_argument("--model", type=str, default='t5-lm', help="We test both t5 and t5-lm in this scripts, the corresponding tokenizerwrapper will be automatically loaded.")
    parser.add_argument("--model_name_or_path", default="t5-large")
    # NOTE(review): machine-specific default path; pass --cache_base to override it.
    parser.add_argument("--cache_base", default='/home/hushengding/plm_cache/')
    parser.add_argument("--keep_non_params", action="store_true", help="Display the modules that does not have parameters, such as nn.Dropout")
    parser.add_argument("--expand_params", action="store_true", help="Display parameter infomation (shape, etc) in seperate lines. ")
    parser.add_argument("--common_structure", action="store_true", help="Whether convert the structure into a common structure defined in structure_mapping.py. The not common structure will be displayed in grey." )
    parser.add_argument("--only_common", action="store_true", help="Whether ignore the modules that are not in common structure. This will result in a more compact view. Default to False")
    args = parser.parse_args()

    # os.path.join works whether or not --cache_base ends with a separator;
    # plain string concatenation glued the two names together without one.
    model_path = os.path.join(args.cache_base, args.model_name_or_path)
    plm, tokenizer, model_config, WrapperClass = load_plm(args.model, model_path)
    print("Model Loaded!")

    if args.common_structure:
        from opendelta.utils.structure_mapping import Mappings
        mapping = Mappings[args.model]
    else:
        mapping = None

    visobj = Visualization(plm)
    visobj.structure_graph(
        rootname=args.model_name_or_path,
        keep_non_params=args.keep_non_params,
        expand_params=args.expand_params,
        common_structure=args.common_structure,
        only_common=args.only_common,
        mapping=mapping,
    )