[Feature] Rule Based Model (#92)

Fixes #89

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#92
This commit is contained in:
2025-09-09 19:32:12 +12:00
parent 1a6608287d
commit 5477b5b3d4
10 changed files with 560 additions and 185 deletions

View File

@ -289,6 +289,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def trig(self, na_replacement=0):
    """Return the slice of this dataset bounded by ``self._triggered``.

    The selection is delegated to ``__getitem__`` with a full slice as the
    first key and ``slice(self._triggered[0], self._triggered[1])`` as the
    second (presumably rows and columns respectively -- confirm against
    ``__getitem__``).

    :param na_replacement: value substituted for every ``None`` cell; pass
        ``None`` to get the selection back untouched.
    :return: the selected rows; a new list of lists when a replacement
        value is given.
    """
    lower, upper = self._triggered[0], self._triggered[1]
    selected = self.__getitem__((slice(None), slice(lower, upper)))
    if na_replacement is None:
        return selected
    return [
        [na_replacement if cell is None else cell for cell in row]
        for row in selected
    ]
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
@ -324,7 +330,7 @@ class Dataset:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path'):
def load(path: 'Path') -> 'Dataset':
import pickle
return pickle.load(open(path, "rb"))
@ -553,6 +559,68 @@ class EnsembleClassifierChain:
return labels / self.num_chains
class RelativeReasoning:
    """Suppress predictions that are dominated by another, co-occurring one.

    ``fit`` learns, for every label column ``i``, which other columns ``j``
    dominate it. ``predict`` copies the feature columns ``start_index`` to
    ``end_index`` (inclusive) of every instance and resets entry ``i`` to 0
    whenever one of its dominators is set in the same instance.

    NOTE(review): ``predict`` looks up ``winmap`` with positions relative to
    ``start_index`` while ``fit`` fills it with column positions of ``Y``;
    this assumes both use the same ordering -- confirm against the caller.
    """

    def __init__(self, start_index: int, end_index: int):
        """Store the inclusive feature window and the default thresholds.

        :param start_index: first feature column read by ``predict``.
        :param end_index: last feature column read by ``predict`` (inclusive).
        """
        self.start_index: int = start_index
        self.end_index: int = end_index
        # Maps a column index to the indices of the columns dominating it.
        self.winmap: Dict[int, List[int]] = defaultdict(list)
        # Minimum number of instances in which the dominator must prevail.
        self.min_count: int = 5
        # Maximum number of contradicting instances; negative disables the cap.
        self.max_count: int = 0

    def fit(self, X, Y):
        """Learn pairwise dominance relations from the label matrix ``Y``.

        Column ``j`` is recorded as dominating column ``i`` when, over all
        instances where both values are not ``None``, ``Y[k][i] < Y[k][j]``
        holds at least ``min_count`` times and more often than the reverse,
        the reverse holds at most ``max_count`` times (unless ``max_count``
        is negative), and both are never equal to 1 together.

        :param X: unused; accepted for interface compatibility.
        :param Y: indexable as ``Y[instance][column]``.
        """
        n_instances = len(Y)
        if n_instances == 0:
            # Nothing to learn from, and Y[0] below would not exist.
            return
        n_attributes = len(Y[0])

        for i in range(n_attributes):
            for j in range(n_attributes):
                if i == j:
                    continue

                wins = 0
                losses = 0
                ties = 0
                for k in range(n_instances):
                    vi = Y[k][i]
                    vj = Y[k][j]
                    if vi is None or vj is None:
                        continue
                    if vi < vj:
                        wins += 1
                    elif vi > vj:
                        losses += 1
                    elif vi == vj and vi == 1:  # both set at the same time
                        ties += 1

                # Enough wins, more wins than losses, losses within the
                # allowed maximum and not a single tie.
                losses_allowed = losses <= self.max_count or self.max_count < 0
                if (
                    wins >= self.min_count
                    and wins > losses
                    and losses_allowed
                    and ties == 0
                ):
                    self.winmap[i].append(j)

    def predict(self, X):
        """Copy the windowed features and clear the dominated entries.

        :param X: sequence of instances, each sliceable by column position.
        :return: numpy array with one row per instance and
            ``end_index + 1 - start_index`` columns.
        """
        n_rules = self.end_index + 1 - self.start_index
        res = np.zeros((len(X), n_rules))

        for inst_idx, inst in enumerate(X):
            window = inst[self.start_index: self.end_index + 1]
            for i, t in enumerate(window):
                res[inst_idx][i] = t
                if not t:
                    continue
                dominators = self.winmap.get(i, ())
                if not dominators:
                    continue
                for i2, t2 in enumerate(window):
                    # A dominator only overrides this entry when it is set
                    # in the same instance.
                    if i2 != i and t2 and i2 in dominators:
                        res[inst_idx][i] = 0
                        break
        return res

    def predict_proba(self, X):
        """Return the same hard 0/1 output as ``predict``."""
        return self.predict(X)
class ApplicabilityDomainPCA(PCA):