# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from .dxo import DXO, DataKind, from_shareable
from .filter import Filter, FilterContextKey
from .fl_constant import ReturnCode
from .fl_context import FLContext
from .shareable import Shareable
from .utils.fl_context_utils import add_job_audit_event
class DXOFilter(Filter, ABC):
    """Base class for DXO-based filters.

    A subclass implements ``process_dxo``. ``process`` converts the incoming
    shareable to a DXO, walks it (descending into COLLECTION DXOs), and calls
    ``process_dxo`` on every DXO whose data kind is selected for filtering.
    """

    def __init__(self, supported_data_kinds: Union[None, List[str]], data_kinds_to_filter: Union[None, List[str]]):
        """Validate and store the data kinds this filter is applied to.

        Args:
            supported_data_kinds: kinds of DXO this filter supports. Empty means all kinds.
            data_kinds_to_filter: kinds of DXO data to filter. Empty means all kinds.

        Raises:
            ValueError: if either argument is non-empty but not a list, or if
                ``data_kinds_to_filter`` contains a kind that is not listed in a
                non-empty ``supported_data_kinds``.
        """
        Filter.__init__(self)

        if supported_data_kinds and not isinstance(supported_data_kinds, list):
            raise ValueError(f"supported_data_kinds must be a list of str but got {type(supported_data_kinds)}")

        if data_kinds_to_filter and not isinstance(data_kinds_to_filter, list):
            raise ValueError(f"data_kinds_to_filter must be a list of str but got {type(data_kinds_to_filter)}")

        if supported_data_kinds and data_kinds_to_filter:
            if not all(dk in supported_data_kinds for dk in data_kinds_to_filter):
                # Report the supported kinds (not the requested ones again) so the
                # message tells the user what is actually allowed.
                raise ValueError(f"invalid data kinds: {data_kinds_to_filter}. Only support {supported_data_kinds}")

        if not data_kinds_to_filter:
            # Nothing requested explicitly: filter everything the filter supports.
            data_kinds_to_filter = supported_data_kinds

        # Empty/None here means "filter all kinds" (see _filter_dxos).
        self.data_kinds = data_kinds_to_filter

    def process(self, shareable: Shareable, fl_ctx: FLContext):
        """Filter the DXO(s) carried by the shareable.

        Args:
            shareable: the shareable to be filtered
            fl_ctx: the FL context

        Returns:
            The input shareable unchanged if its return code is not OK, if it
            cannot be converted to a DXO, or if the DXO has no data; otherwise
            the result of ``update_shareable`` called on the (possibly replaced)
            top-level DXO.
        """
        rc = shareable.get_return_code()
        if rc != ReturnCode.OK:
            # don't process if RC not OK
            return shareable

        try:
            dxo = from_shareable(shareable)
        except Exception:
            # not a DXO based shareable - pass it through untouched.
            # Only Exception is caught so KeyboardInterrupt/SystemExit still propagate.
            return shareable

        if dxo.data is None:
            self.log_debug(fl_ctx, "DXO has no data to filter")
            return shareable

        # Wrap the top-level DXO in a list so _filter_dxos can replace it in place.
        start = [dxo]
        self._filter_dxos(start, shareable, fl_ctx)
        result_dxo = start[0]
        return result_dxo.update_shareable(shareable)

    @abstractmethod
    def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]:
        """Subclass must implement this method to filter the provided DXO

        Args:
            dxo: the DXO to be filtered
            shareable: the shareable that the dxo belongs to
            fl_ctx: the FL context

        Returns:
            A DXO object that is the result of the filtering, if filtered;
            None if not filtered.
        """
        pass

    def _apply_filter(self, dxo: DXO, shareable, fl_ctx: FLContext) -> DXO:
        """Run ``process_dxo`` on a single DXO and record that the filter was applied.

        Args:
            dxo: the DXO to be filtered
            shareable: the shareable that the dxo belongs to
            fl_ctx: the FL context

        Returns:
            The original DXO if it has no data or ``process_dxo`` returned a
            falsy value; otherwise the DXO returned by ``process_dxo``.

        Raises:
            RuntimeError: if ``process_dxo`` returns a truthy value that is not a DXO.
        """
        if not dxo.data:
            self.log_debug(fl_ctx, "DXO has no data to filter")
            return dxo

        filter_name = self.__class__.__name__
        result = self.process_dxo(dxo, shareable, fl_ctx)
        if not result:
            # not filtered
            return dxo

        if not isinstance(result, DXO):
            raise RuntimeError(f"Result from {filter_name} is {type(result)} - must be DXO")

        if result is not dxo:
            # result is a new DXO object - copy filter history from original dxo.
            # Identity (not equality) is what decides whether history must be carried over.
            result.add_filter_history(dxo.get_filter_history())

        result.add_filter_history(filter_name)
        chain_type = self.get_prop(FilterContextKey.CHAIN_TYPE, "?")
        source = self.get_prop(FilterContextKey.SOURCE, "?")
        add_job_audit_event(fl_ctx=fl_ctx, msg=f"applied filter: {filter_name}@{source} on {chain_type}")
        return result

    def _filter_dxos(self, dxo_collection: Union[List[DXO], Dict[str, DXO]], shareable, fl_ctx):
        """Filter every eligible DXO in a list or dict, replacing entries in place.

        DXOs of kind COLLECTION are descended into recursively; other DXOs are
        filtered when ``self.data_kinds`` is empty or contains their data kind.

        Args:
            dxo_collection: list or dict of DXOs to walk
            shareable: the shareable that the DXOs belong to
            fl_ctx: the FL context

        Raises:
            ValueError: if ``dxo_collection`` is neither a list nor a dict.
        """
        if isinstance(dxo_collection, list):
            for i, v in enumerate(dxo_collection):
                if not isinstance(v, DXO):
                    # non-DXO list entries are left untouched
                    continue

                if v.data_kind == DataKind.COLLECTION:
                    self._filter_dxos(v.data, shareable, fl_ctx)
                elif not self.data_kinds or v.data_kind in self.data_kinds:
                    dxo_collection[i] = self._apply_filter(v, shareable, fl_ctx)
        elif isinstance(dxo_collection, dict):
            for k, v in dxo_collection.items():
                # NOTE(review): unlike the list branch, a non-DXO value is treated as
                # an error here; the assert is skipped under `python -O`.
                assert isinstance(v, DXO)
                if v.data_kind == DataKind.COLLECTION:
                    self._filter_dxos(v.data, shareable, fl_ctx)
                elif not self.data_kinds or v.data_kind in self.data_kinds:
                    # Rebinding an existing key does not change the dict's size,
                    # so it is safe while iterating over items().
                    dxo_collection[k] = self._apply_filter(v, shareable, fl_ctx)
        else:
            raise ValueError(f"DXO COLLECTION must be a dict or list but got {type(dxo_collection)}")