nvflare.app_opt.pt.job_config.fed_sag_mlflow module
- class SAGMLFlowJob(initial_model: Module, n_clients: int, num_rounds: int, name: str = 'fed_job', min_clients: int = 1, mandatory_clients: List[str] | None = None, key_metric: str = 'accuracy', tracking_uri=None, kwargs=None, artifact_location=None)
Bases: BaseFedJob
PyTorch ScatterAndGather with MLFlow Job.
Configures the server-side ScatterAndGather controller, a persistor holding the initial model, and widgets.
The user must add executors.
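A minimal sketch of how such a job might be assembled, using only the constructor arguments listed in the signature above. The `ScriptRunner` executor, the `job.to(...)` call, the `client.py` script name, the `site-N` client names, and the `file:///tmp/mlruns` tracking URI are assumptions for illustration and are not stated on this page; check them against the installed NVFlare version. The imports sit inside the function so that the snippet can be loaded even where `torch` and `nvflare` are not installed.

```python
def build_job(n_clients: int = 2, num_rounds: int = 3):
    """Build a SAGMLFlowJob and attach one executor per client (sketch)."""
    # Deferred imports: torch and nvflare are only needed when the job is built.
    import torch.nn as nn
    from nvflare.app_opt.pt.job_config.fed_sag_mlflow import SAGMLFlowJob
    # Assumption: ScriptRunner is the executor helper available in this install.
    from nvflare.job_config.script_runner import ScriptRunner

    job = SAGMLFlowJob(
        initial_model=nn.Linear(10, 2),     # placeholder model
        n_clients=n_clients,
        num_rounds=num_rounds,
        name="sag_mlflow_demo",             # hypothetical job name
        key_metric="accuracy",
        tracking_uri="file:///tmp/mlruns",  # hypothetical local MLflow store
    )

    # The job only configures the server side, so executors are added here.
    # "client.py" and the "site-N" names are hypothetical.
    for i in range(n_clients):
        job.to(ScriptRunner(script="client.py"), f"site-{i + 1}")

    return job
```

Calling `build_job()` returns the configured job object; how it is then exported or run depends on the NVFlare job API in use.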
- Parameters:
initial_model (nn.Module) – initial PyTorch Model
n_clients (int) – number of clients for this job
num_rounds (int) – number of rounds for FedAvg
name (str, optional) – name of the job. Defaults to “fed_job”.
min_clients (int, optional) – the minimum number of clients for the job. Defaults to 1.
mandatory_clients (List[str], optional) – clients that must participate for the job to run. Defaults to None.
key_metric (str, optional) – metric used to decide whether a model is the best global model so far. If the reported metrics are a dict, key_metric selects the entry used for global model selection. Defaults to “accuracy”.
kwargs (dict, optional) – dictionary of additional keyword arguments. Defaults to None.
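The role of key_metric described above can be illustrated with a small standalone sketch. This is not NVFlare's implementation; the helper names `pick_metric` and `is_new_best` are hypothetical, and the sketch assumes that a higher metric value is better.

```python
from typing import Mapping, Optional, Union

Metrics = Union[float, Mapping[str, float]]


def pick_metric(metrics: Metrics, key_metric: str = "accuracy") -> float:
    """Return the value used for model selection.

    If metrics is a dict, key_metric chooses the entry; a bare number is used as-is.
    """
    if isinstance(metrics, Mapping):
        return float(metrics[key_metric])
    return float(metrics)


def is_new_best(metrics: Metrics, best_so_far: Optional[float],
                key_metric: str = "accuracy") -> bool:
    """Assumption: larger values are better (true for accuracy, not for loss)."""
    value = pick_metric(metrics, key_metric)
    return best_so_far is None or value > best_so_far


# Three rounds of reported metrics: two dicts and one bare number.
best = None
for round_metrics in [{"accuracy": 0.71, "loss": 0.9},
                      {"accuracy": 0.69, "loss": 0.8},
                      0.75]:
    if is_new_best(round_metrics, best):
        best = pick_metric(round_metrics)
```

With a metric where lower is better, such as a loss, the comparison in `is_new_best` would need to be reversed.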