import sys
import argparse
from cogdl.datasets import DATASET_REGISTRY, try_import_dataset
from cogdl.models import MODEL_REGISTRY, try_import_model
from cogdl.tasks import TASK_REGISTRY
from cogdl.trainers import TRAINER_REGISTRY, try_import_trainer
def get_parser():
    """Build the base parser holding the options shared by every entry point.

    Returns:
        argparse.ArgumentParser: a parser created with
        ``conflict_handler="resolve"``, so registering an option string a
        second time replaces the earlier definition instead of raising.
    """
    base = argparse.ArgumentParser(conflict_handler="resolve")
    add = base.add_argument

    # NOTE: the --log-interval and --tensorboard-logdir options are currently
    # disabled and therefore not registered here.

    # Reproducibility and optimisation hyper-parameters.
    add(
        "--seed",
        type=int,
        nargs="+",
        default=[1],
        metavar="N",
        help="pseudo random number generator seed",
    )
    add("--max-epoch", type=int, default=500)
    add("--patience", default=100, type=int)
    add("--lr", type=float, default=0.01)
    add("--weight-decay", type=float, default=5e-4)

    # Device selection.
    add("--cpu", help="use CPU instead of CUDA", action="store_true")
    add("--device-id", type=int, nargs="+", default=[0], help="which GPU to use")

    # Persistence and configuration switches.
    add("--save-dir", type=str, default=".")
    add("--checkpoint", help="load pre-trained model", action="store_true")
    add("--use-best-config", help="use best config", action="store_true")

    return base
def add_task_args(parser):
    """Attach the "Task configuration" option group to ``parser``.

    Registers ``--task`` / ``-t``, whose accepted values are the keys of
    ``TASK_REGISTRY``.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    task_group = parser.add_argument_group("Task configuration")
    # NOTE: the option is required, so argparse reports an error for a missing
    # --task instead of falling back to the declared default.
    task_options = {
        "required": True,
        "default": "node_classification",
        "choices": TASK_REGISTRY.keys(),
        "metavar": "TASK",
        "help": "Task",
    }
    task_group.add_argument("--task", "-t", **task_options)
    return task_group
def add_dataset_args(parser):
    """Attach the "Dataset and data loading" option group to ``parser``.

    Registers the required ``--dataset`` / ``-dt`` option, which takes one or
    more dataset names and stores them as a list.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    dataset_group = parser.add_argument_group("Dataset and data loading")
    dataset_group.add_argument(
        "--dataset",
        "-dt",
        required=True,
        nargs="+",
        metavar="DATASET",
        help="Dataset",
    )
    return dataset_group
def add_model_args(parser):
    """Attach the "Model configuration" option group to ``parser``.

    Registers the required ``--model`` / ``-m`` option (one or more model
    names, stored as a list) and the optional ``--fast-spmm`` flag.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    model_group = parser.add_argument_group("Model configuration")
    model_group.add_argument(
        "--model",
        "-m",
        required=True,
        nargs="+",
        metavar="MODEL",
        help="Model Architecture",
    )
    model_group.add_argument(
        "--fast-spmm",
        required=False,
        action="store_true",
        help="whether to use gespmm",
    )
    return model_group
def add_trainer_args(parser):
    """Attach the "Trainer configuration" option group to ``parser``.

    Registers the optional ``--trainer`` option and the integer
    ``--eval-step`` option (default 1).

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The argument group that was created.
    """
    trainer_group = parser.add_argument_group("Trainer configuration")
    trainer_group.add_argument(
        "--trainer",
        required=False,
        metavar="TRAINER",
        help="Trainer",
    )
    trainer_group.add_argument("--eval-step", default=1, type=int)
    return trainer_group
def get_training_parser():
    """Return the base parser extended with the task, dataset, model and trainer groups.

    The groups are registered in that order on the parser produced by
    ``get_parser()``.
    """
    training_parser = get_parser()
    group_builders = (
        add_task_args,
        add_dataset_args,
        add_model_args,
        add_trainer_args,
    )
    for build_group in group_builders:
        build_group(training_parser)
    return training_parser
def get_display_data_parser():
    """Return the base parser extended with the dataset group and ``--depth``.

    ``--depth`` is an integer option defaulting to 3.
    """
    display_parser = get_parser()
    add_dataset_args(display_parser)
    display_parser.add_argument("--depth", type=int, default=3)
    return display_parser
def get_download_data_parser():
    """Return the base parser extended with only the dataset option group."""
    download_parser = get_parser()
    add_dataset_args(download_parser)
    return download_parser
def get_default_args(task: str, dataset, model, **kwargs):
    """Build a fully populated argument namespace without a real command line.

    A synthetic command line (``-t <task> -m <models...> -dt <datasets...>``)
    is installed in ``sys.argv`` while the training parser runs, because
    ``parse_args_and_arch`` re-parses from ``sys.argv``. The caller's original
    ``sys.argv`` is restored afterwards, even if parsing fails, so this helper
    no longer leaks a process-wide side effect.

    Args:
        task: task name passed to ``-t``.
        dataset: a dataset name, or a list/tuple of names, passed to ``-dt``.
        model: a model name, or a list/tuple of names, passed to ``-m``.
        **kwargs: attributes set on the resulting namespace after parsing,
            overriding parsed values with the same name.

    Returns:
        The namespace returned by ``parse_args_and_arch`` with ``kwargs``
        applied on top.
    """
    datasets = list(dataset) if isinstance(dataset, (list, tuple)) else [dataset]
    models = list(model) if isinstance(model, (list, tuple)) else [model]

    original_argv = sys.argv
    # Keep the program name; replace everything else so that unrelated
    # arguments of the host process cannot reach the parser.
    sys.argv = [original_argv[0], "-t", task, "-m", *models, "-dt", *datasets]
    try:
        parser = get_training_parser()
        args, _ = parser.parse_known_args()
        args = parse_args_and_arch(parser, args)
    finally:
        # Always hand the real command line back to the caller.
        sys.argv = original_argv

    for key, value in kwargs.items():
        setattr(args, key, value)
    return args
def parse_args_and_arch(parser, args):
    """Register task/model/dataset/trainer specific options, then parse again.

    The base parser does not know about component-specific options, so the
    preliminary ``args`` is used to look up the selected components, each of
    which may add its own options to ``parser`` before the final parse.

    Args:
        parser: the parser to extend and then parse with.
        args: preliminary namespace providing ``task``, ``model`` and
            ``dataset``, and optionally ``trainer``.

    Returns:
        The namespace produced by a second ``parser.parse_args()`` call, which
        reads the process command line.
    """
    # Registration order matters only if two components declare the same
    # option: task first, then models, datasets and finally the trainer.
    task_cls = TASK_REGISTRY[args.task]
    task_cls.add_args(parser)

    for model_name in args.model:
        if not try_import_model(model_name):
            continue
        MODEL_REGISTRY[model_name].add_args(parser)

    for dataset_name in args.dataset:
        if not try_import_dataset(dataset_name):
            continue
        dataset_cls = DATASET_REGISTRY[dataset_name]
        # Datasets are not required to expose extra options.
        if hasattr(dataset_cls, "add_args"):
            dataset_cls.add_args(parser)

    trainer_requested = "trainer" in args and args.trainer is not None
    if trainer_requested and try_import_trainer(args.trainer):
        trainer_cls = TRAINER_REGISTRY[args.trainer]
        # Trainers are not required to expose extra options either.
        if hasattr(trainer_cls, "add_args"):
            trainer_cls.add_args(parser)

    # Second pass, now that every component-specific option is registered.
    return parser.parse_args()
def get_task_model_args(task, model=None):
    """Return the default arguments for a task and, optionally, one model.

    The training parser is extended with the options of ``task`` and, when
    ``model`` is given and importable, of that model. It is then run on a
    placeholder command line passed directly to ``parse_args``, so the
    process-wide ``sys.argv`` is left untouched.

    Args:
        task: key into ``TASK_REGISTRY``.
        model: optional model name; when given it is stored on the result as
            a plain value (not a list), replacing the placeholder.

    Returns:
        The parsed namespace with ``task`` set to ``task`` and, if ``model``
        was given, ``model`` set to ``model``.
    """
    parser = get_training_parser()
    TASK_REGISTRY[task].add_args(parser)
    if model is not None and try_import_model(model):
        MODEL_REGISTRY[model].add_args(parser)

    # "gcn" and "cora" only satisfy the required --model/--dataset options;
    # the model entry is overwritten below when a model was requested.
    placeholder_argv = ["-t", task, "-m", "gcn", "-dt", "cora"]
    args = parser.parse_args(placeholder_argv)

    args.task = task
    if model is not None:
        args.model = model
    return args