Source code for cogdl.experiments

import copy
import itertools
import os
from collections import defaultdict, namedtuple

import torch
import yaml
import optuna
from tabulate import tabulate

from cogdl.options import get_default_args
from cogdl.tasks import build_task
from cogdl.utils import set_random_seed, tabulate_results, init_operator_configs
from cogdl.configs import BEST_CONFIGS
from cogdl.datasets import SUPPORTED_DATASETS
from cogdl.models import SUPPORTED_MODELS

class AutoML(object):
    """Hyper-parameter search over ``raw_experiment`` driven by optuna.

    Every trial asks ``func_search`` for a set of parameter overrides, runs
    ``raw_experiment`` with them and scores the trial by the mean of a single
    metric over the returned results. The best-scoring trial is remembered.

    Args:
        task: task name forwarded to ``raw_experiment``.
        dataset: dataset forwarded to ``raw_experiment``.
        model: model forwarded to ``raw_experiment``.
        n_trials: number of optuna trials to run.
        **kwargs: must contain ``func_search``, the function to obtain
            hyper-parameters to search; it receives the object optuna passes
            to the objective and returns a mapping of parameter overrides.
            ``seed`` (default ``[1]``) is removed from the mapping and passed
            to ``raw_experiment`` separately. ``metric`` optionally names the
            result key to maximise. Everything left in ``kwargs`` (including
            ``func_search`` and ``metric``) is forwarded to ``raw_experiment``
            on every trial.
    """

    def __init__(self, task, dataset, model, n_trials=3, **kwargs):
        self.task = task
        self.dataset = dataset
        self.model = model
        self.seed = kwargs.pop("seed", [1])
        assert "func_search" in kwargs, "AutoML requires a `func_search` keyword argument"
        self.func_search = kwargs["func_search"]
        self.metric = kwargs.get("metric")
        self.n_trials = n_trials
        self.best_value = None
        self.best_params = None
        # Defined up front so the attribute exists before any trial has run.
        self.best_results = None
        self.default_params = kwargs

    def _objective(self, trials):
        """Run one trial and return the mean of the selected metric.

        Raises:
            KeyError: if no metric was configured and the first result has no
                key containing ``"Val"`` or ``"val"``.
        """
        cur_params = self.func_search(trials)
        # Merge into a copy: updating ``self.default_params`` in place would
        # carry parameters suggested in this trial over into later trials.
        params = dict(self.default_params)
        params.update(cur_params)

        result_dict = raw_experiment(
            task=self.task, dataset=self.dataset, model=self.model, seed=self.seed, **params
        )
        # Only the first variant's list of results is scored.
        result_list = list(result_dict.values())[0]
        item = result_list[0]

        key = self.metric
        if key is None:
            # Fall back to the first key that looks like a validation metric.
            for _key in item.keys():
                if "Val" in _key or "val" in _key:
                    key = _key
                    break
            if key is None:
                raise KeyError("Unable to find validation metrics")

        val = [result[key] for result in result_list]
        mean = sum(val) / len(val)

        if self.best_value is None or mean > self.best_value:
            self.best_value = mean
            self.best_params = cur_params
            self.best_results = result_list
        return mean

    def run(self):
        """Run ``n_trials`` trials, print the study's best parameters and
        return the result list of the best trial."""
        study = optuna.create_study(direction="maximize")
        study.optimize(self._objective, n_trials=self.n_trials, n_jobs=1)
        print(study.best_params)
        return self.best_results
def set_best_config(args):
    """Overwrite attributes of ``args`` with values from ``BEST_CONFIGS``.

    The configuration is looked up by ``args.task`` and then ``args.model``.
    When the model has an entry, its ``"general"`` section is applied first,
    followed by the section named after ``args.dataset`` if one exists.
    ``args`` is modified in place and also returned.
    """
    task_configs = BEST_CONFIGS[args.task]
    if args.model not in task_configs:
        return args
    model_configs = task_configs[args.model]

    def _apply(section):
        # Copy every entry of one config section onto ``args``.
        for name, value in section.items():
            setattr(args, name, value)

    _apply(model_configs["general"])
    if args.dataset in model_configs:
        _apply(model_configs[args.dataset])
    return args
def train(args):
    """Build the task described by ``args``, train it and return the result.

    Selects CUDA device ``args.device_id[0]`` when CUDA is available and
    ``args.cpu`` is falsy, seeds the random generators with ``args.seed`` and,
    if ``args.use_best_config`` is set, applies the stored best configuration
    before the task is built. The final ``args`` are printed.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id[0])

    set_random_seed(args.seed)

    if getattr(args, "use_best_config", False):
        args = set_best_config(args)

    print(args)
    return build_task(args).train()
def gen_variants(**items):
    """Return an iterator over the Cartesian product of the keyword values.

    Each keyword maps to an iterable of candidate values; every combination is
    produced as a ``Variant`` named tuple whose fields are the keyword names,
    in the order the keywords were given.
    """
    field_names = items.keys()
    Variant = namedtuple("Variant", field_names)
    combinations = itertools.product(*items.values())
    return map(Variant._make, combinations)
def variant_args_generator(args, variants):
    """Yield one independent copy of ``args`` per variant.

    Each variant must unpack into ``(dataset, model, seed)``; those values
    replace the corresponding attributes on the yielded copy. The ``args``
    object passed in is left unmodified.
    """
    for variant in variants:
        # Copy first so the caller's ``args`` keeps its original
        # dataset/model/seed instead of ending up with the last variant's.
        variant_args = copy.deepcopy(args)
        # The variant values are deep-copied too, so yielded objects never
        # share mutable state with each other or with the caller.
        variant_args.dataset, variant_args.model, variant_args.seed = copy.deepcopy(tuple(variant))
        yield variant_args
def check_task_dataset_model_match(task, variants):
    """Drop variants whose (model, dataset) pair is not listed for ``task``.

    The allowed pairs are read from ``match.yml`` located next to this module.
    A variant is rejected only when both its dataset and its model are among
    the supported names and the pair is not listed for the task; variants that
    use a dataset or model outside the supported names are kept as they are.
    A message is printed for every rejected variant.

    Args:
        task: name of the task; must be a key of ``match.yml``.
        variants: iterable of objects with ``dataset`` and ``model`` attributes.

    Returns:
        The list of variants that passed the check.

    Raises:
        NotImplementedError: if ``task`` has no entry in ``match.yml``.
        SystemExit: with status 0 if no variant is left after filtering.
    """
    match_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "match.yml")
    with open(match_path, "r", encoding="utf8") as f:
        # NOTE(review): the file is bundled with the package, so FullLoader is
        # kept; consider yaml.safe_load if the file only uses plain YAML types.
        match = yaml.load(f, Loader=yaml.FullLoader)

    objective = match.get(task)
    if objective is None:
        raise NotImplementedError(f"Task '{task}' has no entry in match.yml.")

    # A set gives constant-time membership tests in the loop below.
    pairs = {(model, dataset) for item in objective for model in item["model"] for dataset in item["dataset"]}

    clean_variants = []
    for item in variants:
        is_known = (item.dataset in SUPPORTED_DATASETS) and (item.model in SUPPORTED_MODELS)
        if is_known and (item.model, item.dataset) not in pairs:
            print(f"({item.model}, {item.dataset}) is not implemented in task '{task}'.")
            continue
        clean_variants.append(item)

    if not clean_variants:
        # Raise SystemExit directly: the bare ``exit`` helper is injected by
        # the ``site`` module and is not guaranteed to exist (e.g. ``python -S``).
        raise SystemExit(0)
    return clean_variants
def output_results(results_dict, tablefmt="github"):
    """Print ``results_dict`` as a table.

    The column headers are ``"Variant"`` followed by the keys of the last
    result stored for the first variant; the rows come from
    ``tabulate_results``. ``tablefmt`` is passed through to ``tabulate``.
    """
    first_variant = list(results_dict)[0]
    last_result = results_dict[first_variant][-1]
    headers = ["Variant", *last_result.keys()]
    rows = tabulate_results(results_dict)
    print(tabulate(rows, headers=headers, tablefmt=tablefmt))
def raw_experiment(task: str, dataset, model, **kwargs):
    """Train every (dataset, model, seed) variant and print a result table.

    Arguments are taken from ``kwargs["args"]`` when given, otherwise they are
    built with ``get_default_args`` from ``task``, ``dataset``, ``model`` and
    the remaining keyword arguments. Variants that do not match the task are
    filtered out before training.

    Returns:
        A ``defaultdict`` mapping each ``(dataset, model)`` tuple to the list
        of training results obtained for it, one per seed.
    """
    if "args" in kwargs:
        args = kwargs["args"]
    else:
        args = get_default_args(task=task, dataset=dataset, model=model, **kwargs)

    init_operator_configs(args)

    all_variants = gen_variants(dataset=args.dataset, model=args.model, seed=args.seed)
    variants = check_task_dataset_model_match(task, list(all_variants))

    results_dict = defaultdict(list)
    per_variant_args = variant_args_generator(args, variants)
    for variant, variant_args in zip(variants, per_variant_args):
        # Drop the trailing seed so runs with different seeds share one key.
        results_dict[variant[:-1]].append(train(variant_args))

    tablefmt = kwargs.get("tablefmt", "github")
    output_results(results_dict, tablefmt)
    return results_dict
def auto_experiment(task: str, dataset, model, **kwargs):
    """Run an ``AutoML`` hyper-parameter search for every (dataset, model) pair.

    Args:
        task: task name.
        dataset: iterable of dataset names.
        model: iterable of model names.
        **kwargs: forwarded to ``AutoML``; ``tablefmt`` (default ``"github"``)
            also selects the format of the final table.

    Returns:
        A ``defaultdict`` mapping each ``(dataset, model)`` tuple to the
        results returned by ``AutoML.run`` for that pair.
    """
    variants = list(gen_variants(dataset=dataset, model=model))
    variants = check_task_dataset_model_match(task, variants)

    results_dict = defaultdict(list)
    for variant in variants:
        tool = AutoML(task, variant.dataset, variant.model, **kwargs)
        # Run the search and store its best results; previously the tool was
        # built but never run, and the table-format string was stored instead.
        results_dict[variant[:]] = tool.run()

    tablefmt = kwargs.get("tablefmt", "github")
    print("\nFinal results:\n")
    output_results(results_dict, tablefmt)
    return results_dict
def experiment(task: str, dataset, model, **kwargs):
    """Entry point: run a plain experiment or a hyper-parameter search.

    When ``func_search`` is among the keyword arguments the call is routed to
    ``auto_experiment``, after wrapping a single dataset or model name given
    as a string into a one-element list. Otherwise everything is passed
    unchanged to ``raw_experiment``.
    """
    if "func_search" not in kwargs:
        return raw_experiment(task, dataset, model, **kwargs)

    datasets = [dataset] if isinstance(dataset, str) else dataset
    models = [model] if isinstance(model, str) else model
    return auto_experiment(task, datasets, models, **kwargs)