forked from enviPath/enviPy
[Feature] Rule Based Model (#92)
Fixes #89 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#92
This commit is contained in:
@ -289,6 +289,12 @@ class Dataset:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
def trig(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
|
||||
if na_replacement is not None:
|
||||
res = [[x if x is not None else na_replacement for x in row] for row in res]
|
||||
return res
|
||||
|
||||
|
||||
def y(self, na_replacement=0):
|
||||
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
|
||||
@ -324,7 +330,7 @@ class Dataset:
|
||||
pickle.dump(self, fh)
|
||||
|
||||
@staticmethod
|
||||
def load(path: 'Path'):
|
||||
def load(path: 'Path') -> 'Dataset':
|
||||
import pickle
|
||||
return pickle.load(open(path, "rb"))
|
||||
|
||||
@ -553,6 +559,68 @@ class EnsembleClassifierChain:
|
||||
return labels / self.num_chains
|
||||
|
||||
|
||||
class RelativeReasoning:
|
||||
def __init__(self, start_index: int, end_index: int):
|
||||
self.start_index: int = start_index
|
||||
self.end_index: int = end_index
|
||||
self.winmap: Dict[int, List[int]] = defaultdict(list)
|
||||
self.min_count: int = 5
|
||||
self.max_count: int = 0
|
||||
|
||||
def fit(self, X, Y):
|
||||
n_instances = len(Y)
|
||||
n_attributes = len(Y[0])
|
||||
|
||||
for i in range(n_attributes):
|
||||
for j in range(n_attributes):
|
||||
if i == j:
|
||||
continue
|
||||
|
||||
countwin = 0
|
||||
countloose = 0
|
||||
countboth = 0
|
||||
|
||||
for k in range(n_instances):
|
||||
vi = Y[k][i]
|
||||
vj = Y[k][j]
|
||||
|
||||
if vi is None or vj is None:
|
||||
continue
|
||||
|
||||
if vi < vj:
|
||||
countwin += 1
|
||||
elif vi > vj:
|
||||
countloose += 1
|
||||
elif vi == vj and vi == 1: # tie
|
||||
countboth += 1
|
||||
|
||||
# We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
|
||||
if (
|
||||
countwin >= self.min_count and
|
||||
countwin > countloose and
|
||||
(
|
||||
countloose <= self.max_count or
|
||||
self.max_count < 0
|
||||
) and
|
||||
countboth == 0
|
||||
):
|
||||
self.winmap[i].append(j)
|
||||
|
||||
def predict(self, X):
|
||||
res = np.zeros((len(X), (self.end_index + 1 - self.start_index)))
|
||||
|
||||
for inst_idx, inst in enumerate(X):
|
||||
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
res[inst_idx][i] = t
|
||||
if t:
|
||||
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
|
||||
if i != i2 and i2 in self.winmap.get(i, []) and X[t2]:
|
||||
res[inst_idx][i] = 0
|
||||
|
||||
return res
|
||||
|
||||
def predict_proba(self, X):
|
||||
return self.predict(X)
|
||||
|
||||
|
||||
class ApplicabilityDomainPCA(PCA):
|
||||
|
||||
Reference in New Issue
Block a user