nvflare.app_common.workflows.splitnn_workflow module

class SplitNNConstants

Bases: object

BATCH_INDICES = '_splitnn_batch_indices_'
BATCH_SIZE = '_splitnn_batch_size_'
DATA = '_splitnn_data_'
TARGET_NAMES = '_splitnn_target_names_'
TASK_INIT_MODEL = '_splitnn_task_init_model_'
TASK_RESULT = '_splitnn_task_result_'
TASK_TRAIN = '_splitnn_task_train_'
TASK_TRAIN_LABEL_STEP = '_splitnn_task_train_label_step_'
TASK_VALID_LABEL_STEP = '_splitnn_task_valid_label_step_'
TIMEOUT = 60.0
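The task-name constants are plain strings. The controller's default init_model_task_name and train_task_name are the TASK_INIT_MODEL and TASK_TRAIN values, so a client-side component can dispatch on them. Below is a minimal sketch, assuming a hypothetical `TaskRouter` class and hypothetical result dicts. Neither is an NVFlare API. Only the string values are taken from the constants listed above.

```python
from typing import Callable, Dict

# String values copied from SplitNNConstants above.
TASK_INIT_MODEL = "_splitnn_task_init_model_"
TASK_TRAIN = "_splitnn_task_train_"
TASK_TRAIN_LABEL_STEP = "_splitnn_task_train_label_step_"
TASK_VALID_LABEL_STEP = "_splitnn_task_valid_label_step_"

Handler = Callable[[dict], dict]


class TaskRouter:
    """Hypothetical client-side dispatcher from task names to handlers."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Handler] = {}

    def register(self, task_name: str, handler: Handler) -> None:
        self._handlers[task_name] = handler

    def execute(self, task_name: str, data: dict) -> dict:
        handler = self._handlers.get(task_name)
        if handler is None:
            # Report unsupported task names instead of raising.
            return {"status": "unknown_task", "task_name": task_name}
        return handler(data)


router = TaskRouter()
router.register(TASK_INIT_MODEL, lambda data: {"status": "ok", "step": "init"})
router.register(TASK_TRAIN, lambda data: {"status": "ok", "step": "train"})
router.register(TASK_TRAIN_LABEL_STEP, lambda data: {"status": "ok", "step": "label"})
router.register(TASK_VALID_LABEL_STEP, lambda data: {"status": "ok", "step": "valid"})
```

With this approach, a single component can serve every split learning task. Unrecognized names also get a uniform reply.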
class SplitNNController(num_rounds: int = 5000, start_round: int = 0, persistor_id='persistor', shareable_generator_id='shareable_generator', init_model_task_name='_splitnn_task_init_model_', train_task_name='_splitnn_task_train_', task_timeout: int = 10, ignore_result_error: bool = True, batch_size: int = 256)

Bases: Controller

The controller for the split learning workflow.

The SplitNNController workflow defines split learning training among the participating clients. The model persistor (persistor_id) loads the initial global model, which is sent to all clients. During training, the clients exchange intermediate activations and gradients, which are the data kinds listed in SplitNNDataKind. The shareable generator (shareable_generator_id) converts model weights to a Shareable and back. The model persistor also saves the model after training.

  • num_rounds (int, optional) – The total number of training rounds. Defaults to 5000.

  • start_round (int, optional) – Start round for training. Defaults to 0.

  • persistor_id (str, optional) – ID of the persistor component. Defaults to “persistor”.

  • shareable_generator_id (str, optional) – ID of the shareable generator. Defaults to “shareable_generator”.

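  • init_model_task_name (str, optional) – Task name used to initialize the local models. Defaults to “_splitnn_task_init_model_”.

  • train_task_name (str, optional) – Task name used for split learning training. Defaults to “_splitnn_task_train_”.

  • task_timeout (int, optional) – Timeout (in seconds) used to determine whether a client has failed to request the task assigned to it. Defaults to 10.

  • ignore_result_error (bool, optional) – Whether this controller can proceed if a result has errors. Defaults to True.

  • batch_size (int, optional) – Batch size used for split learning training. Defaults to 256.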

Raises:

  • TypeError – when any input argument does not have the correct type

  • ValueError – when any input argument is out of range
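The TypeError/ValueError contract above can be illustrated with a standalone check. The `validate_splitnn_args` helper below is a hypothetical sketch and is not the controller's own validation code. The bounds it enforces are assumptions: non-negative round counts and timeout, and a positive batch size. The defaults mirror the signature.

```python
def validate_splitnn_args(
    num_rounds: int = 5000,
    start_round: int = 0,
    persistor_id: str = "persistor",
    shareable_generator_id: str = "shareable_generator",
    task_timeout: int = 10,
    ignore_result_error: bool = True,
    batch_size: int = 256,
) -> None:
    """Hypothetical argument check mirroring the documented Raises section."""
    # Type checks: bool is a subclass of int, so exclude it explicitly.
    for name, value in (("num_rounds", num_rounds), ("start_round", start_round),
                        ("task_timeout", task_timeout), ("batch_size", batch_size)):
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    for name, value in (("persistor_id", persistor_id),
                        ("shareable_generator_id", shareable_generator_id)):
        if not isinstance(value, str):
            raise TypeError(f"{name} must be a str, got {type(value).__name__}")
    if not isinstance(ignore_result_error, bool):
        raise TypeError("ignore_result_error must be a bool")

    # Range checks (assumed bounds).
    if num_rounds < 0:
        raise ValueError("num_rounds must be >= 0")
    if start_round < 0:
        raise ValueError("start_round must be >= 0")
    if task_timeout < 0:
        raise ValueError("task_timeout must be >= 0")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
```

The checks run in two passes. Type errors are reported first, so range checks never compare values of the wrong type.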

control_flow(abort_signal: Signal, fl_ctx: FLContext)

This is the control logic for the RUN.

NOTE: this is running in a separate thread, and its life is the duration of the RUN.

  • fl_ctx – the FL context

  • abort_signal – the abort signal. If triggered, this method stops waiting and returns to the caller.
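The abort behavior can be shown with a simplified round loop that checks the signal before each round and returns to the caller once it is triggered. In the sketch below, `SimpleSignal` is a hypothetical stand-in for `Signal` that exposes a `triggered` flag. `run_rounds` is likewise a hypothetical outline. It omits the task scheduling the real control logic performs.

```python
import threading
from typing import Callable, List


class SimpleSignal:
    """Hypothetical stand-in for an abort signal with a `triggered` flag."""

    def __init__(self) -> None:
        self._event = threading.Event()

    def trigger(self) -> None:
        self._event.set()

    @property
    def triggered(self) -> bool:
        return self._event.is_set()


def run_rounds(num_rounds: int, start_round: int, abort_signal: SimpleSignal,
               do_round: Callable[[int], None]) -> List[int]:
    """Run rounds [start_round, num_rounds) until done or aborted."""
    completed = []
    for current_round in range(start_round, num_rounds):
        if abort_signal.triggered:
            # Stop and return to the caller instead of starting another round.
            break
        do_round(current_round)
        completed.append(current_round)
    return completed
```

Using `threading.Event` keeps the flag safe to set from another thread. This matters because control_flow runs in its own thread for the duration of the RUN.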

handle_event(event_type: str, fl_ctx: FLContext)

Handles events.

  • event_type (str) – event type fired by the workflow.

  • fl_ctx (FLContext) – FLContext information.

process_result_of_unknown_task(client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext)

Process result when no task is found for it.

This is called when a result submission is received from a client, but no standing task for it can be found in the task queue.

This could happen when:

  • the client’s submission is too late

  • the task is already completed

  • the Controller lost the task, e.g. because the Server was restarted

  • client – the client that the result comes from

  • task_name – the name of the task

  • client_task_id – ID of the task

  • result – the result from the client

  • fl_ctx – the FL context that comes with the client’s submission
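The three cases listed above can be modeled with a small bookkeeping class. In the sketch below, `TaskBook` and its status strings are hypothetical. The sketch only classifies why a submission has no standing task and does not describe what the real controller does with such a result.

```python
from typing import Dict, Set


class TaskBook:
    """Hypothetical bookkeeping of standing, completed, and expired tasks."""

    def __init__(self) -> None:
        self.standing: Dict[str, str] = {}  # client_task_id -> task_name
        self.completed: Set[str] = set()
        self.expired: Set[str] = set()

    def send(self, client_task_id: str, task_name: str) -> None:
        self.standing[client_task_id] = task_name

    def expire(self, client_task_id: str) -> None:
        # The task timed out before the client submitted its result.
        if self.standing.pop(client_task_id, None) is not None:
            self.expired.add(client_task_id)

    def submit(self, client_task_id: str) -> str:
        if client_task_id in self.standing:
            del self.standing[client_task_id]
            self.completed.add(client_task_id)
            return "accepted"
        if client_task_id in self.completed:
            return "unknown: task already completed"
        if client_task_id in self.expired:
            return "unknown: submission too late"
        # Never seen, e.g. the server restarted and lost its task queue.
        return "unknown: task lost"
```

Only the first branch corresponds to a standing task. The other three branches are the situations in which a method like process_result_of_unknown_task would be called.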

start_controller(fl_ctx: FLContext)

Starts the controller.

This method is called at the beginning of the RUN.

  • fl_ctx – the FL context. You can use this context to access services provided by the framework. For example, you can get Command Register from it and register your admin command modules.

stop_controller(fl_ctx: FLContext)

Stops the controller.

This method is called right before the RUN is ended.

  • fl_ctx – the FL context. You can use this context to access services provided by the framework. For example, you can get Command Register from it and unregister your admin command modules.
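start_controller and stop_controller form a matched pair. Anything registered at the beginning of the RUN should be unregistered before it ends. In the sketch below, `CommandRegistry`, `LifecycleController`, and the command name are hypothetical stand-ins. They do not reproduce the framework's Command Register API.

```python
from typing import Callable, Dict


class CommandRegistry:
    """Hypothetical registry of admin command handlers."""

    def __init__(self) -> None:
        self.commands: Dict[str, Callable[[], str]] = {}

    def register(self, name: str, handler: Callable[[], str]) -> None:
        self.commands[name] = handler

    def unregister(self, name: str) -> None:
        self.commands.pop(name, None)


class LifecycleController:
    """Registers a status command on start and removes it on stop."""

    COMMAND = "splitnn_status"  # hypothetical command name

    def __init__(self, registry: CommandRegistry) -> None:
        self.registry = registry
        self.current_round = 0

    def start_controller(self) -> None:
        # The lambda reads current_round at call time, so it reports live state.
        self.registry.register(self.COMMAND, lambda: f"round={self.current_round}")

    def stop_controller(self) -> None:
        self.registry.unregister(self.COMMAND)
```

Pairing register and unregister in this way ensures that no handler referencing the finished controller remains after the RUN.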

class SplitNNDataKind

Bases: object

ACTIVATIONS = '_splitnn_activations_'
GRADIENT = '_splitnn_gradient_'
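The two data kinds match the two directions of a split learning step. The data-holding side sends activations forward. The label-holding side returns the gradient of the loss with respect to those activations, and the data side finishes backpropagation with the chain rule. The sketch below uses a hypothetical scalar toy model so the arithmetic can be checked by hand: h = w1·x, ŷ = w2·h, and loss (ŷ − y)². Dicts tagged with the kind strings stand in for the real message exchange.

```python
from typing import Tuple

ACTIVATIONS = "_splitnn_activations_"
GRADIENT = "_splitnn_gradient_"


def data_side_forward(w1: float, x: float) -> dict:
    # Bottom model on the data-holding client: h = w1 * x.
    return {"kind": ACTIVATIONS, "value": w1 * x}


def label_side_step(w2: float, msg: dict, y: float) -> Tuple[float, dict]:
    # Top model on the label-holding client: y_hat = w2 * h, loss = (y_hat - y)^2.
    assert msg["kind"] == ACTIVATIONS
    h = msg["value"]
    err = w2 * h - y
    grad_w2 = 2.0 * err * h  # dL/dw2, used to update the top model locally
    grad_h = 2.0 * err * w2  # dL/dh, sent back to the data side
    return grad_w2, {"kind": GRADIENT, "value": grad_h}


def data_side_backward(msg: dict, x: float) -> float:
    # Chain rule on the data side: dL/dw1 = dL/dh * dh/dw1 = dL/dh * x.
    assert msg["kind"] == GRADIENT
    return msg["value"] * x


act = data_side_forward(0.5, 2.0)
grad_w2, grad_msg = label_side_step(3.0, act, 1.0)
grad_w1 = data_side_backward(grad_msg, 2.0)
# grad_w2 == 4.0, grad_w1 == 24.0
```

Neither side sees the other's weights. The data side never sees the labels, and the label side sees only h.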