Skip to content
Snippets Groups Projects
Unverified Commit 7531515c authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

James review changes

Threshold checks done outside of _semantic_classify.

Testing is more efficient as dl._query() is not run across every threshold.
parent b23d85a6
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
......@@ -18,9 +18,8 @@ class DecisionLayer:
def __call__(self, text: str, _method: str='raw', _threshold: float=0.5):
results = self._query(text)
decision = self._semantic_classify(results, _method=_method, _threshold=_threshold)
# return decision
return decision
top_class, top_class_scores = self._semantic_classify(results)
# TODO: Once we determine a threshold methodology, we can use it here.
def add(self, decision: Decision):
self._add_decision(devision=decision)
......@@ -67,47 +66,19 @@ class DecisionLayer:
]
def _semantic_classify(self, query_results: dict, _method: str='raw', _threshold: float=0.5):
"""Given some text, categorizes."""
def _semantic_classify(self, query_results: dict):
# Initialize score dictionaries
scores_by_class = {}
highest_score_by_class = {}
# Define valid methods
valid_methods = ['raw', 'tan', 'max_score_in_top_class']
# Check if method is valid
if _method not in valid_methods:
raise ValueError(f"Invalid method: {_method}")
# Apply the scoring system to the results and group by category
for result in query_results:
decision = result['decision']
score = result['score']
# Apply tan transformation if method is 'tan'
if _method == 'tan':
score = np.tan(score * (np.pi / 2))
# Update scores_by_class
scores_by_class[decision] = scores_by_class.get(decision, 0) + score
# Update highest_score_by_class for 'max_score_in_top_class' method
if _method == 'max_score_in_top_class':
highest_score_by_class[decision] = max(score, highest_score_by_class.get(decision, 0))
# Sort the categories by score in descending order
sorted_classes = sorted(scores_by_class.items(), key=lambda x: x[1], reverse=True)
# Determine if the score is sufficiently high
predicted_class = None
if sorted_classes:
top_class, top_score = sorted_classes[0]
if _method == 'max_score_in_top_class':
top_score = highest_score_by_class[top_class]
if top_score > _threshold:
predicted_class = top_class
# Return the category with the highest total score
return predicted_class, scores_by_class
\ No newline at end of file
decision = result['decision']
if decision in scores_by_class:
scores_by_class[decision].append(score)
else:
scores_by_class[decision] = [score]
# Calculate total score for each class
total_scores = {decision: sum(scores) for decision, scores in scores_by_class.items()}
top_class = max(total_scores, key=total_scores.get, default=None)
# Return the top class and its associated scores
return top_class, scores_by_class.get(top_class, [])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment