import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import random
from dyngpt._nnmodel._nnmodel import NNmodel
from dyngpt._utils._util_infer import *
from dyngpt._utils._util_train import *
from dyngpt._utils._util import *
import math
import time
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from dyngpt._nnmodel._nnmodel import NNmodel
from dyngpt._nnmodel._nnmodel import NNmodel_FT
# Module-level side effects: these three assignments run once, when this module is imported.
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort that can occur
# when several libraries bundle their own OpenMP copy — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
def solving_state_transition_network(model_name, params=None, initial_value=None, args=""):
    """Solve the state transition network; alias of :func:`solving_STN`.

    This function used to carry a line-for-line copy of ``solving_STN``. It now
    forwards to it so the two entry points cannot drift apart.

    Args:
        model_name (str): Forwarded unchanged to ``solving_STN``.
        params (array-like, optional): Forwarded to ``solving_STN``. ``None`` is
            forwarded as an empty list, which was the previous default.
        initial_value (list, optional): Forwarded to ``solving_STN``. ``None`` is
            forwarded as an empty list, which was the previous default.
        args (Namespace): Model/run configuration, forwarded unchanged.

    Returns:
        Whatever ``solving_STN`` returns for the same arguments.
    """
    # Fresh lists are created per call instead of sharing one mutable default object.
    return solving_STN(
        model_name,
        params=[] if params is None else params,
        initial_value=[] if initial_value is None else initial_value,
        args=args,
    )
def pre_training(config_pretrain):
    """Pre-train the neural network model for state transition network.

    Every optimisation step draws one prompt at random, samples ``args.batch_size``
    outcomes from that prompt's recorded outcome table, and minimises the loss
    returned by ``net(pro_sam, y_label)``.

    Side effects: ``args.out_dir`` and ``args.epochs`` are overwritten on the passed
    configuration object; a checkpoint and the loss history are written whenever
    ``epoch % 100 == 0``; the final weights are written after the last epoch.

    Args:
        config_pretrain (Namespace): Configuration object containing training parameters.

    Returns:
        dict: ``"loss_means"`` (per-epoch mean loss) and ``"weight_nn_pretrain_path"``
        (path of the final weight file).

    Raises:
        ValueError: If the loaded training dataset contains no prompts.
    """
    args = config_pretrain
    dataset_path = args.train_dataset_path
    # These assignments precede NNmodel(**vars(args)), so the model constructor sees them too.
    args.out_dir = args.result_dir
    args.epochs = args.epochs_pretrain
    device = args.device

    net = NNmodel(**vars(args))
    net.to(device)
    optimizer = net.configure_optimizers(
        args.weight_decay,
        args.lr,
        (args.beta1, args.beta2),
        device_type=args.device,
    )
    if args.load_weight:
        # map_location lets a checkpoint saved on another device be resumed on `device`.
        model_state = torch.load(args.weight_file_path, map_location=device)
        net.load_state_dict(model_state["net_state_dict"])
        optimizer.load_state_dict(model_state["optimizer_state_dict"])

    ensure_dir(args.out_dir + "weights/")
    ensure_dir(args.out_dir + "train_state/")
    ensure_dir(args.out_dir + "train_loss/")

    loss_means = []
    # Never filled during pre-training; still saved so the .npz keeps its "loss_std" key.
    loss_stds = []

    prompt_prob_datasets = load_promt_prob(args, file_path=dataset_path)
    if len(prompt_prob_datasets) == 0:
        raise ValueError(f"No prompts were loaded from {dataset_path!r}")

    data_size = 100  # optimisation steps per epoch
    batch_size = args.batch_size

    for epoch in range(args.start_epoch, args.epochs):
        lr = get_lr(epoch) if args.decay_lr else args.lr
        set_learning_rate(optimizer, lr)
        epoch_losses = []
        optimizer.zero_grad()
        random.shuffle(prompt_prob_datasets)

        for _ in range(data_size):
            # randrange covers every prompt; the former randint(0, len - 2) could never pick
            # the last one and raised on a single-prompt dataset.
            prompt_prob = prompt_prob_datasets[random.randrange(len(prompt_prob_datasets))]

            # Batch construction is pure data preparation, so no graph is needed.
            with torch.no_grad():
                prompt = prompt_prob[0]
                if torch.isinf(torch.tensor(prompt)).any():
                    continue

                outcome_keys = list(prompt_prob[1].keys())
                if args.weighted_sample:
                    # Draw outcomes in proportion to their recorded counts.
                    sample_weight = np.array(list(prompt_prob[1].values()))
                    sample_weight = sample_weight / sample_weight.sum()
                    samples_str = np.random.choice(outcome_keys, batch_size, p=sample_weight)
                else:
                    samples_str = np.random.choice(outcome_keys, batch_size)

                # Outcome keys are "_"-joined integers, e.g. "3_0_12".
                samples_int = np.array([el.split("_") for el in samples_str]).astype(int)

                # Truncate samples using the maximum value
                for c_index, upper in enumerate(args.constrains):
                    samples_int[samples_int[:, c_index] >= upper, c_index] = upper - 1

                prompt = np.repeat(np.array(prompt)[np.newaxis, :], batch_size, axis=0)
                pro_sam = np.hstack([prompt, samples_int])
                pro_sam = torch.as_tensor(pro_sam, dtype=args.default_dtype_torch, device=device)
                y_label = torch.as_tensor(samples_int, dtype=torch.long, device=device)

            # NOTE(review): the scraped source lost its indentation; one optimiser step per
            # sampled prompt is the reading adopted here (zero_grad directly precedes backward).
            loss_cross = net(pro_sam, y_label)
            epoch_losses.append(loss_cross.detach().cpu().numpy())
            optimizer.zero_grad(set_to_none=True)
            loss_cross.backward()
            if args.clip_grad:
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad)
            optimizer.step()

        ls = np.mean(epoch_losses)
        loss_means.append(ls)
        print(
            f"|epoch {epoch}| cross loss: {ls:.4f}, lr: {optimizer.param_groups[0]['lr']:.8f}")

        if epoch % 100 == 0:
            weight_path = f"{args.out_dir}/weights/{args.model}_epoch{epoch}.pt"
            torch.save({
                "net_state_dict": net.state_dict(),
                "dtype": args.dtype,
                "optimizer_state_dict": optimizer.state_dict(),
            }, weight_path)
            loss_data_path = f"{args.out_dir}/train_loss/{args.model}_epoch{epoch}.npz"
            np.savez(
                loss_data_path,
                loss_mean=np.array(loss_means),
                loss_std=np.array(loss_stds),
            )

    weight_path = f"{args.out_dir}/weights/{args.model}_pretrain_final.pt"
    torch.save({
        "net_state_dict": net.state_dict(),
        "dtype": args.dtype,
        "optimizer_state_dict": optimizer.state_dict(),
    }, weight_path)

    return {
        "loss_means": loss_means,
        "weight_nn_pretrain_path": weight_path,
    }
def fine_tune_training(config_fine_tune):
    """Fine-tune the neural network model using a specified training configuration.

    For each prompt the batch is built from two sources: samples drawn from the network
    itself (``net.sample``) and samples drawn from the prompt's recorded outcome table.
    Both are scored against the recorded outcome frequencies, and the loss selected by
    ``args.loss_type`` is back-propagated. Gradients from up to 100 prompts are
    accumulated before one optimiser step per epoch.

    Side effects: the passed configuration object is modified (``load_weight``,
    ``out_dir``, ``epochs``, ``sample``, ``ssa_sample``, ``sample_num``,
    ``ssa_sample_num``); a checkpoint and the metric history are written whenever
    ``epoch % 100 == 0``; the final weights are written after the last epoch.

    Args:
        config_fine_tune (Namespace): Configuration object containing hyperparameters,
            dataset paths, training settings, and model specifications.

    Raises:
        ValueError: If an unknown loss type is specified in `args.loss_type`, or if the
            loaded training dataset contains no prompts.

    Returns:
        dict: ``"loss_means"`` (per-epoch mean loss) and ``"weight_nn_fine_tune_path"``
        (path of the fine-tuned weight file).
    """
    args = config_fine_tune
    dataset_path = args.train_dataset_path
    # These overrides precede NNmodel_FT(**vars(args)), so the model constructor sees them too.
    args.load_weight = True
    args.out_dir = args.result_dir
    args.epochs = args.epochs_fine_tune
    weight_file_path = args.weight_nn_pretrain_path
    args.sample = True
    args.ssa_sample = True
    args.sample_num = 500
    args.ssa_sample_num = 500

    ensure_dir(args.out_dir + "weights/")
    ensure_dir(args.out_dir + "train_state/")
    ensure_dir(args.out_dir + "train_loss/")

    device = args.device
    net = NNmodel_FT(**vars(args))
    net.to(device)
    optimizer = net.configure_optimizers(
        args.weight_decay,
        args.lr,
        (args.beta1, args.beta2),
        device_type=args.device,
    )
    if args.load_weight:
        # map_location lets pre-trained weights saved on another device be loaded on `device`.
        model_state = torch.load(weight_file_path, map_location=device)
        net.load_state_dict(model_state["net_state_dict"])
        optimizer.load_state_dict(model_state["optimizer_state_dict"])

    loss_means = []
    loss_stds = []
    eds = []
    kls = []

    # load datasets
    prompt_prob_datasets = load_promt_prob(args, file_path=dataset_path)
    if len(prompt_prob_datasets) == 0:
        raise ValueError(f"No prompts were loaded from {dataset_path!r}")
    # Total count recorded for the first prompt; used to turn every count into a frequency.
    simulation_num = sum(prompt_prob_datasets[0][1].values())

    data_size = 100  # maximum number of prompts visited per epoch

    for epoch in range(args.start_epoch, args.epochs):
        if args.decay_lr:
            lr = get_lr(epoch, learning_rate=8e-4, lr_decay_iters=10000, min_lr=6e-5)
        else:
            lr = args.lr
        set_learning_rate(optimizer, lr)
        loss_mean_collection, loss_std_collection = [], []
        kl_distance_collection, euclidean_distance_collection = [], []
        optimizer.zero_grad()  # gradients accumulate over the prompts visited below
        random.shuffle(prompt_prob_datasets)

        # Slicing instead of indexing 0..99 avoids an IndexError on datasets with fewer
        # than `data_size` prompts.
        for prompt_prob in prompt_prob_datasets[:data_size]:
            # Sampling and target construction need no graph.
            with torch.no_grad():
                if args.sample:
                    prompt = torch.as_tensor(
                        prompt_prob[0], dtype=args.default_dtype_torch, device=args.device)
                    if torch.isinf(prompt).any():
                        continue
                    prompt = prompt.repeat(args.batch_size, 1)
                    # pro_sam: prompt + samples as one complete sentence
                    pro_sam, samples = net.sample(prompt)
                    samples = samples.detach()
                    # Look up each sampled outcome by its "_"-joined key; outcomes absent
                    # from the table fall back to args.epsilon2.
                    log_Tp_t = [
                        prompt_prob[1].get('_'.join(map(str, sample.int().tolist())), args.epsilon2)
                        for sample in samples
                    ]
                    log_Tp_t = np.log(np.array(log_Tp_t) / simulation_num)
                    log_Tp_t = torch.as_tensor(
                        log_Tp_t, dtype=args.default_dtype_torch, device=args.device)

                if args.ssa_sample and args.sample:
                    # Keep the network-sampled half before the names are reused below.
                    pro_sam_nn = pro_sam
                    log_Tp_t_nn = log_Tp_t

                if args.ssa_sample:
                    prompt = prompt_prob[0]
                    samples_str = np.random.choice(list(prompt_prob[1].keys()), args.batch_size)
                    samples_int = np.array([el.split("_") for el in samples_str]).astype(int)
                    # Clamp each constrained column to its maximum allowed value.
                    for c_index, upper in enumerate(args.constrains):
                        samples_int[samples_int[:, c_index] >= upper, c_index] = upper - 1
                    prompt = np.repeat(np.array(prompt)[np.newaxis, :], args.batch_size, axis=0)
                    pro_sam = np.hstack([prompt, samples_int])
                    pro_sam = torch.as_tensor(
                        pro_sam, dtype=args.default_dtype_torch, device=args.device)
                    log_Tp_t = [prompt_prob[1].get(sample, args.epsilon) for sample in samples_str]
                    log_Tp_t = np.log(np.array(log_Tp_t) / simulation_num)
                    log_Tp_t = torch.as_tensor(
                        log_Tp_t, dtype=args.default_dtype_torch, device=args.device)

                if args.ssa_sample and args.sample:
                    # Final batch: network-drawn rows followed by dataset-drawn rows.
                    pro_sam = torch.cat(
                        [pro_sam_nn[:args.sample_num, :], pro_sam[:args.ssa_sample_num, :]])
                    log_Tp_t = torch.cat(
                        [log_Tp_t_nn[:args.sample_num], log_Tp_t[:args.ssa_sample_num]])

            log_prob = net.log_joint_prob(pro_sam, unobserved_number=0)

            # NOTE(review): the scraped source lost its indentation. `loss` is treated as a
            # constant weight (computed under no_grad), which matches how `loss_l2` and
            # `loss3` are built from detached tensors — confirm against the upstream file.
            with torch.no_grad():
                prob = torch.exp(log_prob)
                r_prob = torch.exp(log_Tp_t)
                r_prob = r_prob / r_prob.sum()
                # Rescale the reference so it carries the same total mass as `prob`.
                r_prob = r_prob * prob.sum()
                loss = log_prob - log_Tp_t
                loss_l2 = prob - r_prob
            assert not log_Tp_t.requires_grad

            if args.loss_type == 'kl':
                loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            elif args.loss_type == 'klreweight':
                loss3 = prob * loss / prob.mean()
                loss_reinforce = torch.mean((loss3 - loss3.mean()) * log_prob)
            elif args.loss_type == 'l2':
                loss_reinforce = torch.mean((loss_l2 - loss_l2.mean()) * log_prob)
            elif args.loss_type == "crossEntropy":
                loss_reinforce = -torch.sum(r_prob * log_prob)
            else:
                raise ValueError('Unknown loss type: {}'.format(args.loss_type))
            loss_reinforce.backward()

            # Diagnostics only; nothing below contributes to the gradient.
            with torch.no_grad():
                loss_mean = loss.mean()
                loss_std = loss.std()
                euclidean_distance = torch.sqrt(torch.sum((prob - r_prob) ** 2))
                kl_distance = torch.nn.functional.kl_div(log_prob, r_prob, reduction='sum')
            loss_mean_collection.append(loss_mean.detach().cpu().numpy())
            loss_std_collection.append(loss_std.detach().cpu().numpy())
            euclidean_distance_collection.append(euclidean_distance.detach().cpu().numpy())
            kl_distance_collection.append(kl_distance.detach().cpu().numpy())

        # NOTE(review): a single step per epoch over the accumulated gradients is inferred from
        # the zero_grad() at the top of the epoch and the absence of one inside the loop.
        if args.clip_grad:
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad)
        optimizer.step()

        lm = np.mean(loss_mean_collection)
        ls = np.mean(loss_std_collection)
        ed = np.mean(euclidean_distance_collection)
        kl = np.mean(kl_distance_collection)
        loss_means.append(lm)
        loss_stds.append(ls)
        eds.append(ed)
        kls.append(kl)
        print(
            f"|epoch {epoch}|mean loss: {lm:.4f}, std loss: {ls:.4f}, ed: {ed:.4f}, kd: {kl:.4f}, lr: {optimizer.param_groups[0]['lr']:.8f}")

        if epoch % 100 == 0:
            path = f"{args.out_dir}/weights/fine_tune_{args.model}_epoch{epoch}.pt"
            torch.save({
                "net_state_dict": net.state_dict(),
                "dtype": args.dtype,
                "optimizer_state_dict": optimizer.state_dict(),
            }, path)
            path1 = f"{args.out_dir}/train_loss/fine_tune_{args.model}_epoch{epoch}.npz"
            np.savez(
                path1,
                loss_mean=np.array(loss_means),
                loss_std=np.array(loss_stds),
                ed=np.array(eds),
                kl=np.array(kls),
            )

    weight_path = f"{args.out_dir}/weights/fine_tune_{args.model}_final.pt"
    torch.save({
        "net_state_dict": net.state_dict(),
        "dtype": args.dtype,
        "optimizer_state_dict": optimizer.state_dict(),
    }, weight_path)

    return {
        "loss_means": loss_means,
        "weight_nn_fine_tune_path": weight_path,
    }
def solving_STN(model_name, params=None, initial_value=None, args=None):
    """Solve the state transition network using a trained neural network model.

    Loads the packaged weights ``weights/<args.model>_final.pt``, builds one prompt per
    row of ``params`` by appending ``args.initial_value`` to it, and passes the prompts
    to ``nn_sample``.

    Args:
        model_name (str): The name of the model to be used. Currently not read; the
            weight file is selected through ``args.model``.
        params (array-like): 2-D collection of parameter sets, one set per row. Lists are
            accepted and converted to an array.
        initial_value (list, optional): Currently not read; the initial values come from
            ``args.initial_value``.
        args (Namespace): Model configuration, including ``model``, ``device``,
            ``batch_size`` and ``initial_value``. Required despite the default.

    Returns:
        The value returned by ``nn_sample`` for the constructed prompts.

    Raises:
        ValueError: If ``args`` is missing, or ``params`` is not 2-D.
    """
    if args is None or isinstance(args, str):
        raise ValueError("`args` must be a configuration namespace, got {!r}".format(args))

    # The prompt construction below needs `.shape`; accept plain (nested) lists as well.
    params = np.asarray([] if params is None else params)
    if params.ndim != 2:
        raise ValueError(
            "`params` must be 2-D (one parameter set per row), got shape {}".format(params.shape))

    # Kept from the original: pins the process to the first GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # NOTE(review): `pkg_resources`, `NNmodel` and `nn_sample` are not defined in this
    # function; `pkg_resources` has no explicit import among the file's visible imports and
    # presumably arrives through one of the star imports — confirm.
    weight_path = pkg_resources.resource_filename(
        __name__, 'weights/{}_final.pt'.format(args.model))
    # map_location lets weights saved on another device be loaded on args.device.
    state = torch.load(weight_path, map_location=args.device)

    net = NNmodel(**vars(args))
    net.load_state_dict(state["net_state_dict"])
    net.to(args.device)
    net.eval()

    # Append the same initial values to every parameter row.
    n_param_sets = params.shape[0]
    params_prompt = np.hstack((params, np.tile(args.initial_value, (n_param_sets, 1))))
    return nn_sample(
        params_prompt,
        net,
        args=args,
        param_indexs=range(0, n_param_sets, 1),
        bs=args.batch_size,
        true_param=False,
    )