import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import gc
import random
import time
import torch.optim as optim
import torch.nn as nn
from sklearn.neighbors import KernelDensity
from dyngpt._nnmodel._nnmodel import NNmodel,NNmodel_FT
from dyngpt._utils._util import *
from dyngpt._utils._util_infer import *
from dyngpt._utils._util_plotting import *
from dyngpt.plotting._plotting import plot_afl_inference_data_burst_distribution
from dyngpt.plotting._plotting import plot_nm_nm_inference_data_burst_distribution
from dyngpt.plotting._plotting import plot_isc_compare_distribution
def _stable_loss_weights(losses):
    """Return weights proportional to ``exp(-losses)``, normalised to sum to 1.

    Shifting by the minimum loss before exponentiating is mathematically
    identical to ``exp(-l) / exp(-l).sum()`` but cannot underflow to ``0 / 0``
    (NaN) when every loss is large.
    """
    losses = np.asarray(losses)
    unnormalised = np.exp(-(losses - losses.min()))
    return unnormalised / unnormalised.sum()


def _as_array(items):
    """Convert ``items`` to an ndarray, falling back to a 1-D object array if ragged.

    NumPy >= 1.24 raises ``ValueError`` for ragged nested sequences instead of
    building an object array; per-gene traces can differ in length because
    early stopping may end each restart after a different number of steps.
    """
    try:
        return np.array(items)
    except ValueError:
        ragged = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            ragged[i] = item
        return ragged


def inferring_dynamics(observed_data, config_infer, gene_names=None, true_params=None,
                       synthetic_flag=0, new_model=False, synthetic_data=None, max_genes=6):
    """Infer the underlying dynamic parameters of a state transition network from observed data.

    For each of the first ``max_genes`` rows of ``observed_data`` the parameter
    vector is optimised by gradient steps through a frozen fine-tuned network,
    restarted ``config_infer.random_r0_number`` times from random initial
    values. The best parameter is the maximum-density point of all visited
    parameters, weighted by ``softmax(-loss)``. Genes whose inference raises are
    reported with ``print`` and skipped.

    Side effects:
        * ``config_infer`` is mutated: ``lr``, ``loss_type`` and ``batch_size``
          are overwritten (``batch_size`` is recomputed for every gene).
        * ``CUDA_VISIBLE_DEVICES`` defaults to ``"0"`` if not already set.
        * If ``config_infer.result_dir`` is non-empty, results are written to
          ``<result_dir>/inferring_result_<model>.npz``.

    Args:
        observed_data (numpy.ndarray): Observed dataset; ``observed_data[k]``
            holds the observations of gene ``k``.
        config_infer (Namespace): Configuration object containing inference parameters.
        gene_names (list, optional): Names matching the rows of ``observed_data``.
            Defaults to ``range(number of inferred genes)``.
        true_params (list, optional): True parameter values for validation. Defaults to none.
        synthetic_flag (int, optional): 1 if ``observed_data`` is synthetic (no
            unobserved species), 0 otherwise. Defaults to 0.
        new_model (bool, optional): If True, bounds/indices/weights come from
            ``config_infer``; otherwise from the packaged model config and
            ``weights/<model>_final.pt``. Defaults to False.
        synthetic_data (list, optional): Reference data for the KL divergence
            when ``synthetic_flag`` is set; ``observed_data`` is used if empty.
        max_genes (int or None, optional): Maximum number of genes (leading rows)
            to infer; ``None`` infers all rows. Defaults to 6.

    Returns:
        dict: Keys ``observed_datas``, ``nn_sample_datas`` (network samples at the
        best parameters), ``es_params`` (per-gene visited parameters),
        ``es_param`` (best parameter per gene), ``es_loss`` (per-gene losses),
        ``kl_val`` (KL divergence per gene), ``gene_names``, ``true_params``,
        ``model`` and ``state_name``. Parameters are presumably log-scale since
        they are clamped to ``log`` of the bounds — TODO confirm.

    Raises:
        RuntimeError: If inference failed for every gene.
    """
    gene_names = [] if gene_names is None else gene_names
    true_params = [] if true_params is None else true_params
    synthetic_data = [] if synthetic_data is None else synthetic_data

    # Only the last ``-truncated_index`` (loss, param) pairs of each restart are kept.
    truncated_index = -50
    filtered_observed_data = observed_data
    result_dir = config_infer.result_dir
    args = config_infer  # alias: every ``args.x = ...`` below mutates the caller's config
    if max_genes is None:
        n_genes = filtered_observed_data.shape[0]
    else:
        n_genes = min(max_genes, filtered_observed_data.shape[0])
    iter_epoch = args.iteration_steps
    random_r0_number = args.random_r0_number
    model_name = args.model
    args.loss_type = args.inference_loss_func
    # The step size is fixed per model and overrides args.inference_lr.
    # NOTE(review): confirm that ignoring inference_lr is intended.
    args.lr = 8 if model_name == "on_off_nm" else 1
    # Only effective before CUDA is initialised; keep a device choice the caller already made.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    if new_model:
        rub_val, rlb_val = np.array(args.rub_val), np.array(args.rlb_val)
        rub = torch.log(torch.tensor(rub_val, device=args.device))
        rlb = torch.log(torch.tensor(rlb_val, device=args.device))
        # This draw is only used for its length, but it also advances the global
        # NumPy RNG; kept so seeded runs reproduce the original random stream.
        r0 = np.random.uniform(low=rlb_val, high=rub_val)
        grad_thresh = [1] * len(r0)
        observed_data_index = args.observed_data_index
        observed_sample_index = args.observed_data_index
        weight_path = args.weight_nn_fine_tune_path
    else:
        (rub_val, rlb_val, rub, rlb, r0, grad_thresh, _x_labels,
         observed_data_index, observed_sample_index) = get_model_config(model_name, args)
        # Same location pkg_resources.resource_filename(__name__, ...) resolves to,
        # without depending on the deprecated (and here unimported) pkg_resources.
        weight_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   "weights", f"{args.model}_final.pt")
    param_num = args.num_reactions

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(weight_path, map_location=args.device)
    net = NNmodel_FT(**vars(args))
    net.load_state_dict(state["net_state_dict"])
    net.to(args.device)
    net.eval()
    # The network is frozen; only the parameter input is updated during inference.
    for param in net.parameters():
        param.requires_grad = False

    es_param_alls, es_loss_all_nps, es_param_all_nps, sub_filtered_observed_data_li = [], [], [], []
    valid_gene_names, valid_true_params, best_kl_val, kl_true_data_li = [], [], [], []
    if len(gene_names) == 0:
        gene_names = list(range(n_genes))

    if synthetic_flag:
        unobserved_number = 0
    else:
        unobserved_number = args.num_species - len(observed_data_index)

    for k in range(n_genes):
        try:
            gc.collect()
            torch.cuda.empty_cache()
            gene_data = filtered_observed_data[k]
            # Look these up first so a bad index skips the gene before any
            # result list is appended to (keeps all result lists aligned).
            gene_name = gene_names[k]
            true_param = true_params[k] if len(true_params) > 0 else None

            # Collapse repeated observations into unique rows weighted by frequency.
            single_observed_data, counts = np.unique(gene_data, axis=0, return_counts=True)
            loss_weight = counts / counts.sum()
            if unobserved_number == 0:
                sub_filtered_observed_data = single_observed_data
                args.batch_size = sub_filtered_observed_data.shape[0]
            elif len(observed_sample_index) == 1:
                sub_filtered_observed_data = single_observed_data.T.reshape(len(single_observed_data), 1)
                args.batch_size = len(single_observed_data) * 2
            else:
                sub_filtered_observed_data = single_observed_data
                args.batch_size = len(counts) * unobserved_number * 2

            # Reference data the final KL divergence is measured against.
            if synthetic_flag:
                kl_true_data = synthetic_data[k] if len(synthetic_data) > 0 else gene_data
            elif len(observed_sample_index) == 1:
                kl_true_data = gene_data.reshape(gene_data.shape[0], 1)
            else:
                kl_true_data = gene_data

            es_param_all, es_loss_all = [], []
            optimizer = net.configure_optimizers(args.weight_decay, args.lr, (args.beta1, args.beta2),
                                                 device_type=args.device)
            for _ in range(random_r0_number):
                r0 = np.random.uniform(low=rlb_val, high=rub_val)
                if model_name == "arl":
                    r0[2] = r0[2] + 2
                temp_loss, temp_param = [], []
                _, y_label, _, sp, _ = generate_nn_input_observed_data2(
                    r0, args, sub_filtered_observed_data, net, observed_data_index,
                    unobserved_number=unobserved_number)
                update_flag = 0
                for _ in range(iter_epoch):
                    new_grad_means, loss_val = calculate_grad(
                        r0, args, net, optimizer, param_num=param_num, grad_thresh=grad_thresh,
                        y_label1=y_label, sp1=sp, loss_weight=loss_weight,
                        unobserved_number=unobserved_number)
                    nsp = sp.detach() - new_grad_means.detach()
                    # Project each parameter back into its [rlb, rub] box.
                    # NOTE(review): only the first param_num rows are clamped — confirm
                    # the remaining rows of sp need no bound.
                    for i_param in range(param_num):
                        nsp[:param_num, i_param] = torch.clamp(
                            nsp[:param_num, i_param], min=rlb[i_param], max=rub[i_param])
                    sp = nsp.to(args.device)  # follow the network's device, not a hard-coded GPU
                    temp_loss_val = loss_val.detach().cpu().numpy()
                    temp_loss.append(temp_loss_val)
                    temp_param.append(nsp[0, :param_num].cpu().numpy())
                    del loss_val, new_grad_means
                    # Early stop: after 20 steps, quit on the second step whose loss
                    # is above the best loss seen in this restart.
                    if len(temp_loss) > 20 and temp_loss_val > min(temp_loss):
                        update_flag += 1
                        if update_flag > 1:
                            break
                es_loss_all.extend(temp_loss[truncated_index:])
                es_param_all.extend(temp_param[truncated_index:])
            del optimizer

            es_loss_all_np = np.hstack(es_loss_all)
            es_param_all_np = np.vstack(es_param_all)
            weights = _stable_loss_weights(es_loss_all_np)
            best_param = get_max_density_point(es_param_all_np, weights)
            best_param = best_param.reshape(1, best_param.shape[0])
            best_param_prompt = np.hstack((best_param, np.tile(args.initial_value, (best_param.shape[0], 1))))
            best_param_samples = nn_sample(best_param_prompt, net, args=args,
                                           param_indexs=range(best_param.shape[0]),
                                           bs=1000, true_param=False)
            if synthetic_flag:
                kl_divergence_val = calculate_kl_divergence(kl_true_data, best_param_samples[0, :, :])
            else:
                kl_divergence_val = calculate_kl_divergence(
                    kl_true_data, best_param_samples[0, :, observed_sample_index].T)

            best_kl_val.append(kl_divergence_val)
            es_param_alls.append(best_param.flatten())
            es_loss_all_nps.append(es_loss_all_np)
            es_param_all_nps.append(es_param_all_np)
            sub_filtered_observed_data_li.append(gene_data)
            kl_true_data_li.append(kl_true_data)
            valid_gene_names.append(gene_name)
            if true_param is not None:
                valid_true_params.append(true_param)
        except Exception as e:  # best effort: one failing gene must not abort the batch
            print(f"error: gene {k} skipped ({type(e).__name__}: {e})")
    torch.cuda.empty_cache()

    if not es_param_alls:
        raise RuntimeError(
            f"inferring_dynamics: inference failed for all {n_genes} gene(s); see the errors printed above.")

    observed_datas = _as_array(sub_filtered_observed_data_li)
    kl_true_data_np = _as_array(kl_true_data_li)
    es_params = _as_array(es_param_all_nps)
    es_param = np.array(es_param_alls)
    es_param_prompt = np.hstack((es_param, np.tile(args.initial_value, (es_param.shape[0], 1))))
    # args.batch_size still holds the value computed for the last gene attempted.
    nn_sample_datas = nn_sample(es_param_prompt, net, args=args, param_indexs=range(es_param.shape[0]),
                                bs=min(args.batch_size, 1000), true_param=False)
    es_loss = _as_array(es_loss_all_nps)
    valid_gene_names = np.array(valid_gene_names)
    valid_true_params = _as_array(valid_true_params)
    best_kl_val = np.array(best_kl_val)

    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
        np.savez(os.path.join(result_dir, f"inferring_result_{model_name}.npz"),
                 observed_datas=observed_datas, nn_sample_datas=nn_sample_datas, es_params=es_params,
                 es_loss=es_loss, es_param=es_param, gene_names=valid_gene_names, kl_val=best_kl_val,
                 true_params=valid_true_params, synthetic_data=kl_true_data_np)

    return {
        "observed_datas": observed_datas,
        "nn_sample_datas": nn_sample_datas,
        "es_params": es_params,
        "es_param": es_param,
        "es_loss": es_loss,
        "model": config_infer.model,
        "state_name": config_infer.state_name,
        "gene_names": valid_gene_names,
        "kl_val": best_kl_val,
        "true_params": valid_true_params,
    }