@dataclass
class GREAMolecularPredictor(GNNMolecularPredictor):
    """This predictor implements a Graph Rationalization model called GREA.

    The full name of GREA is Graph Rationalization with Environment-based
    Augmentations. During model training, it learns the rationales
    (explainable subgraphs) and uses them for molecular property
    prediction tasks.

    References
    ----------
    - Graph Rationalization with Environment-based Augmentations.
      https://dl.acm.org/doi/10.1145/3534678.3539347
    - Code: https://github.com/liugangcode/GREA

    :param gamma: GREA-specific parameter. Default is 0.4. When tuning, it is
        searched over the range [0.1, 0.9].
    :type gamma: float
    """

    # GREA-specific parameter
    gamma: float = 0.4

    # Override parent defaults
    model_name: str = "GREAMolecularPredictor"
    # Fixed for this predictor; not accepted as a constructor argument.
    model_class: Type[GREA] = field(default=GREA, init=False)

    def __post_init__(self):
        super().__post_init__()

    @staticmethod
    def _get_param_names() -> List[str]:
        """Return hyperparameter names: ``gamma`` followed by the parent's names."""
        return ["gamma"] + GNNMolecularPredictor._get_param_names()

    def _get_default_search_space(self):
        """Return the parent's default search space extended with ``gamma``.

        ``gamma`` is searched as a float in the range (0.1, 0.9).
        """
        # Copy so that adding "gamma" does not mutate the parent's mapping.
        search_space = super()._get_default_search_space().copy()
        search_space["gamma"] = ParameterSpec(ParameterType.FLOAT, (0.1, 0.9))
        return search_space

    def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]:
        """Build the keyword arguments used to construct the GREA model.

        Starts from the parent's parameters and sets ``gamma``. The value comes
        from ``checkpoint["hyperparameters"]`` when a checkpoint with
        hyperparameters is given (falling back to ``self.gamma`` if the key is
        missing), otherwise ``self.gamma`` is used.

        :param checkpoint: Optional checkpoint dictionary.
        :return: Model constructor parameters without ``graph_pooling``.
        """
        base_params = super()._get_model_params(checkpoint)
        if checkpoint and "hyperparameters" in checkpoint:
            base_params["gamma"] = checkpoint["hyperparameters"].get("gamma", self.gamma)
        else:
            base_params["gamma"] = self.gamma
        # Drop the parent's "graph_pooling" entry so it is not forwarded to the
        # model (presumably GREA does not accept it — confirm against GREA).
        base_params.pop("graph_pooling", None)
        return base_params

    def predict(self, X: List[str]) -> Dict[str, Union[np.ndarray, List[List]]]:
        """Make predictions using the fitted model.

        Parameters
        ----------
        X : List[str]
            List of SMILES strings to make predictions for

        Returns
        -------
        Dict[str, Union[np.ndarray, List[List]]]
            Dictionary containing:
                - 'prediction': Model predictions (shape: [n_samples, n_tasks])
                - 'variance': Prediction variances (shape: [n_samples, n_tasks])
                - 'node_importance': A nested list where the outer list has
                  length n_samples and each inner list has length n_nodes for
                  that molecule

            If the loader yields no batches, a warning is issued and
            'prediction' and 'variance' are empty arrays while
            'node_importance' is an empty list.
        """
        self._check_is_fitted()

        # Convert to PyTorch Geometric format and create loader
        X, _ = self._validate_inputs(X)
        dataset = self._convert_to_pytorch_data(X)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Make predictions
        self.model = self.model.to(self.device)
        self.model.eval()
        predictions = []
        variances = []
        node_scores = []
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                batch = batch.to(self.device)
                out = self.model(batch)
                predictions.append(out["prediction"].cpu().numpy())
                variances.append(out["variance"].cpu().numpy())
                # NOTE(review): assumes out["score"] iterates one entry per
                # molecule in the batch — confirm against GREA.forward.
                node_scores.extend(out["score"])

        # predictions and variances are appended together, so checking one
        # of them is enough to detect an empty loader.
        if not predictions:
            warnings.warn("No valid predictions could be made from the input data. Returning empty results.")
            # Keep node_importance a list, matching the non-empty return path.
            return {"prediction": np.array([]), "variance": np.array([]), "node_importance": []}

        return {
            "prediction": np.concatenate(predictions, axis=0),
            "variance": np.concatenate(variances, axis=0),
            "node_importance": node_scores,
        }