diff --git a/coverage.xml b/coverage.xml index ef0c214f529c075212da1cf239a8e603949c24ab..321f6c5c6b8cb01dd3094f0a1fbdc62d4465fe4c 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,12 +1,12 @@ <?xml version="1.0" ?> -<coverage version="7.4.3" timestamp="1710353742486" lines-valid="2080" lines-covered="1625" line-rate="0.7812" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.4.3" timestamp="1710430201229" lines-valid="2107" lines-covered="1642" line-rate="0.7793" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.4.3 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> <source>/Users/andreped/workspace/semantic-router/semantic_router</source> </sources> <packages> - <package name="." line-rate="0.8962" branch-rate="0" complexity="0"> + <package name="." line-rate="0.8928" branch-rate="0" complexity="0"> <classes> <class name="__init__.py" filename="__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -18,7 +18,7 @@ <line number="7" hits="1"/> </lines> </class> - <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="0.972" branch-rate="0"> + <class name="hybrid_layer.py" filename="hybrid_layer.py" complexity="0" line-rate="0.958" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -33,104 +33,116 @@ <line number="18" hits="1"/> <line number="19" hits="1"/> <line number="21" hits="1"/> - <line number="29" hits="1"/> <line number="30" hits="1"/> - <line number="32" hits="1"/> - <line number="33" hits="0"/> + <line number="31" hits="1"/> + <line number="33" hits="1"/> <line number="34" hits="0"/> - <line number="36" hits="1"/> - <line number="38" hits="1"/> + <line number="35" hits="0"/> + <line number="37" hits="1"/> <line number="39" hits="1"/> <line number="40" hits="1"/> - <line number="41" hits="0"/> - <line 
number="42" hits="1"/> + <line number="41" hits="1"/> + <line number="42" hits="0"/> <line number="43" hits="1"/> - <line number="46" hits="1"/> + <line number="44" hits="1"/> + <line number="45" hits="0"/> <line number="48" hits="1"/> - <line number="52" hits="1"/> - <line number="54" hits="1"/> + <line number="49" hits="1"/> + <line number="50" hits="1"/> + <line number="53" hits="1"/> <line number="55" hits="1"/> - <line number="56" hits="1"/> - <line number="57" hits="1"/> - <line number="58" hits="1"/> <line number="59" hits="1"/> <line number="61" hits="1"/> + <line number="62" hits="1"/> <line number="63" hits="1"/> <line number="64" hits="1"/> + <line number="65" hits="1"/> <line number="66" hits="1"/> - <line number="67" hits="1"/> - <line number="69" hits="1"/> + <line number="68" hits="1"/> + <line number="70" hits="1"/> <line number="71" hits="1"/> + <line number="73" hits="1"/> <line number="74" hits="1"/> <line number="76" hits="1"/> - <line number="77" hits="1"/> - <line number="80" hits="1"/> - <line number="82" hits="1"/> - <line number="85" hits="1"/> - <line number="86" hits="1"/> - <line number="88" hits="1"/> + <line number="78" hits="1"/> + <line number="81" hits="1"/> + <line number="83" hits="1"/> + <line number="84" hits="1"/> + <line number="87" hits="1"/> <line number="89" hits="1"/> - <line number="90" hits="1"/> <line number="92" hits="1"/> - <line number="94" hits="1"/> + <line number="93" hits="1"/> <line number="95" hits="1"/> - <line number="98" hits="1"/> + <line number="96" hits="1"/> + <line number="97" hits="1"/> <line number="99" hits="1"/> + <line number="101" hits="1"/> <line number="102" hits="1"/> - <line number="103" hits="1"/> - <line number="104" hits="1"/> + <line number="105" hits="1"/> + <line number="106" hits="1"/> + <line number="109" hits="1"/> <line number="110" hits="1"/> <line number="111" hits="1"/> - <line number="113" hits="1"/> - <line number="119" hits="1"/> + <line number="117" hits="1"/> + <line 
number="118" hits="1"/> <line number="120" hits="1"/> - <line number="122" hits="1"/> - <line number="128" hits="1"/> - <line number="133" hits="1"/> - <line number="134" hits="1"/> - <line number="136" hits="1"/> - <line number="137" hits="1"/> - <line number="139" hits="1"/> + <line number="126" hits="1"/> + <line number="127" hits="1"/> + <line number="129" hits="1"/> + <line number="135" hits="1"/> + <line number="140" hits="1"/> <line number="141" hits="1"/> <line number="143" hits="1"/> <line number="144" hits="1"/> - <line number="145" hits="1"/> - <line number="147" hits="1"/> + <line number="146" hits="1"/> <line number="148" hits="1"/> - <line number="149" hits="1"/> <line number="150" hits="1"/> + <line number="151" hits="1"/> <line number="152" hits="1"/> - <line number="153" hits="1"/> <line number="154" hits="1"/> + <line number="155" hits="1"/> <line number="156" hits="1"/> <line number="157" hits="1"/> <line number="159" hits="1"/> <line number="160" hits="1"/> - <line number="162" hits="1"/> + <line number="161" hits="1"/> + <line number="163" hits="1"/> <line number="164" hits="1"/> - <line number="165" hits="1"/> <line number="166" hits="1"/> - <line number="168" hits="1"/> + <line number="167" hits="1"/> <line number="169" hits="1"/> - <line number="170" hits="1"/> <line number="171" hits="1"/> <line number="172" hits="1"/> <line number="173" hits="1"/> - <line number="174" hits="1"/> + <line number="175" hits="1"/> <line number="176" hits="1"/> + <line number="177" hits="1"/> + <line number="178" hits="1"/> <line number="179" hits="1"/> <line number="180" hits="1"/> - <line number="183" hits="1"/> - <line number="184" hits="1"/> - <line number="186" hits="1"/> + <line number="181" hits="1"/> + <line number="183" hits="0"/> <line number="187" hits="1"/> + <line number="188" hits="1"/> <line number="189" hits="1"/> <line number="190" hits="1"/> <line number="191" hits="1"/> + <line number="192" hits="1"/> <line number="193" hits="1"/> + <line 
number="195" hits="1"/> + <line number="198" hits="1"/> + <line number="202" hits="1"/> + <line number="205" hits="1"/> + <line number="206" hits="1"/> + <line number="208" hits="1"/> + <line number="209" hits="1"/> + <line number="211" hits="1"/> + <line number="212" hits="1"/> + <line number="213" hits="1"/> + <line number="215" hits="1"/> </lines> </class> - <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8936" branch-rate="0"> + <class name="layer.py" filename="layer.py" complexity="0" line-rate="0.8889" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -204,217 +216,232 @@ <line number="110" hits="1"/> <line number="111" hits="1"/> <line number="113" hits="1"/> - <line number="114" hits="1"/> + <line number="115" hits="1"/> + <line number="118" hits="1"/> <line number="119" hits="1"/> - <line number="120" hits="1"/> - <line number="122" hits="1"/> + <line number="121" hits="1"/> + <line number="125" hits="1"/> <line number="126" hits="1"/> - <line number="127" hits="1"/> - <line number="133" hits="1"/> + <line number="132" hits="1"/> + <line number="134" hits="1"/> <line number="135" hits="1"/> - <line number="136" hits="1"/> + <line number="138" hits="1"/> <line number="139" hits="1"/> - <line number="140" hits="1"/> - <line number="144" hits="1"/> - <line number="147" hits="1"/> - <line number="148" hits="0"/> + <line number="143" hits="1"/> + <line number="146" hits="1"/> + <line number="147" hits="0"/> + <line number="149" hits="1"/> <line number="150" hits="1"/> <line number="151" hits="1"/> <line number="152" hits="1"/> <line number="153" hits="1"/> - <line number="154" hits="1"/> + <line number="155" hits="1"/> <line number="156" hits="1"/> <line number="157" hits="1"/> - <line number="158" hits="1"/> + <line number="159" hits="1"/> <line number="160" hits="1"/> <line number="161" hits="1"/> <line number="162" hits="1"/> <line number="163" hits="1"/> <line number="164" hits="1"/> - <line number="165" hits="1"/> + 
<line number="166" hits="1"/> <line number="167" hits="1"/> - <line number="168" hits="1"/> - <line number="169" hits="0"/> + <line number="168" hits="0"/> + <line number="170" hits="1"/> <line number="171" hits="1"/> - <line number="172" hits="1"/> + <line number="174" hits="1"/> <line number="175" hits="1"/> <line number="176" hits="1"/> <line number="177" hits="1"/> - <line number="178" hits="1"/> - <line number="180" hits="1"/> - <line number="187" hits="1"/> + <line number="179" hits="1"/> <line number="188" hits="1"/> <line number="189" hits="1"/> <line number="190" hits="1"/> - <line number="194" hits="1"/> - <line number="196" hits="1"/> + <line number="191" hits="1"/> + <line number="195" hits="1"/> <line number="197" hits="1"/> <line number="198" hits="1"/> <line number="199" hits="1"/> + <line number="200" hits="1"/> <line number="201" hits="1"/> <line number="202" hits="1"/> - <line number="203" hits="1"/> + <line number="203" hits="0"/> + <line number="204" hits="1"/> <line number="205" hits="1"/> - <line number="207" hits="1"/> + <line number="206" hits="0"/> <line number="209" hits="1"/> - <line number="210" hits="1"/> - <line number="211" hits="1"/> - <line number="212" hits="0"/> - <line number="216" hits="0"/> - <line number="217" hits="1"/> - <line number="219" hits="1"/> - <line number="226" hits="1"/> - <line number="227" hits="1"/> + <line number="212" hits="1"/> + <line number="213" hits="1"/> + <line number="214" hits="1"/> + <line number="216" hits="1"/> + <line number="218" hits="1"/> + <line number="220" hits="1"/> + <line number="221" hits="1"/> + <line number="222" hits="1"/> + <line number="223" hits="0"/> + <line number="227" hits="0"/> <line number="228" hits="1"/> - <line number="229" hits="1"/> - <line number="231" hits="1"/> - <line number="232" hits="1"/> - <line number="234" hits="1"/> - <line number="235" hits="1"/> - <line number="236" hits="1"/> + <line number="230" hits="1"/> + <line number="237" hits="1"/> + <line 
number="238" hits="1"/> <line number="239" hits="1"/> - <line number="240" hits="0"/> - <line number="241" hits="0"/> - <line number="247" hits="0"/> - <line number="248" hits="0"/> - <line number="250" hits="0"/> - <line number="251" hits="1"/> - <line number="252" hits="1"/> - <line number="253" hits="0"/> - <line number="260" hits="1"/> + <line number="240" hits="1"/> + <line number="242" hits="1"/> + <line number="243" hits="1"/> + <line number="245" hits="1"/> + <line number="246" hits="1"/> + <line number="247" hits="1"/> + <line number="250" hits="1"/> + <line number="251" hits="0"/> + <line number="252" hits="0"/> + <line number="258" hits="0"/> + <line number="259" hits="0"/> + <line number="261" hits="0"/> <line number="262" hits="1"/> - <line number="270" hits="1"/> - <line number="272" hits="1"/> - <line number="274" hits="1"/> - <line number="275" hits="1"/> - <line number="277" hits="1"/> + <line number="263" hits="1"/> + <line number="264" hits="0"/> + <line number="271" hits="1"/> + <line number="273" hits="1"/> <line number="281" hits="1"/> - <line number="282" hits="0"/> <line number="283" hits="1"/> + <line number="285" hits="1"/> + <line number="286" hits="1"/> <line number="288" hits="1"/> - <line number="290" hits="1"/> - <line number="291" hits="0"/> - <line number="297" hits="1"/> - <line number="298" hits="1"/> + <line number="292" hits="1"/> + <line number="293" hits="0"/> + <line number="294" hits="1"/> <line number="299" hits="1"/> - <line number="300" hits="1"/> <line number="301" hits="1"/> - <line number="303" hits="1"/> - <line number="304" hits="1"/> - <line number="305" hits="1"/> - <line number="306" hits="1"/> - <line number="307" hits="1"/> + <line number="302" hits="0"/> + <line number="308" hits="1"/> <line number="309" hits="1"/> <line number="310" hits="1"/> <line number="311" hits="1"/> <line number="312" hits="1"/> <line number="314" hits="1"/> <line number="315" hits="1"/> + <line number="316" hits="1"/> <line 
number="317" hits="1"/> - <line number="319" hits="1"/> + <line number="318" hits="1"/> <line number="320" hits="1"/> + <line number="321" hits="1"/> + <line number="322" hits="1"/> <line number="323" hits="1"/> + <line number="325" hits="1"/> + <line number="326" hits="1"/> <line number="328" hits="1"/> <line number="330" hits="1"/> <line number="331" hits="1"/> - <line number="333" hits="1"/> - <line number="334" hits="0"/> - <line number="336" hits="1"/> + <line number="334" hits="1"/> + <line number="339" hits="1"/> + <line number="341" hits="1"/> <line number="342" hits="1"/> - <line number="343" hits="1"/> <line number="344" hits="1"/> - <line number="345" hits="1"/> + <line number="345" hits="0"/> <line number="347" hits="1"/> - <line number="348" hits="1"/> - <line number="350" hits="1"/> - <line number="352" hits="0"/> - <line number="367" hits="1"/> - <line number="369" hits="1"/> - <line number="372" hits="1"/> - <line number="374" hits="1"/> - <line number="376" hits="1"/> - <line number="382" hits="1"/> + <line number="353" hits="1"/> + <line number="354" hits="1"/> + <line number="355" hits="1"/> + <line number="356" hits="1"/> + <line number="358" hits="1"/> + <line number="359" hits="1"/> + <line number="361" hits="1"/> + <line number="363" hits="0"/> + <line number="378" hits="1"/> + <line number="380" hits="1"/> + <line number="383" hits="1"/> <line number="385" hits="1"/> - <line number="386" hits="1"/> <line number="387" hits="1"/> - <line number="389" hits="1"/> - <line number="392" hits="1"/> <line number="393" hits="1"/> - <line number="395" hits="1"/> <line number="396" hits="1"/> <line number="397" hits="1"/> <line number="398" hits="1"/> - <line number="399" hits="1"/> <line number="400" hits="1"/> - <line number="401" hits="1"/> <line number="403" hits="1"/> + <line number="404" hits="1"/> <line number="406" hits="1"/> <line number="407" hits="1"/> + <line number="408" hits="1"/> + <line number="409" hits="1"/> <line number="410" 
hits="1"/> <line number="411" hits="1"/> - <line number="413" hits="0"/> + <line number="412" hits="1"/> <line number="414" hits="0"/> - <line number="416" hits="1"/> - <line number="417" hits="1"/> <line number="418" hits="1"/> + <line number="419" hits="1"/> <line number="420" hits="1"/> + <line number="421" hits="1"/> <line number="422" hits="1"/> + <line number="423" hits="1"/> + <line number="424" hits="1"/> <line number="426" hits="1"/> - <line number="427" hits="1"/> - <line number="428" hits="1"/> - <line number="432" hits="1"/> + <line number="429" hits="1"/> <line number="433" hits="1"/> - <line number="439" hits="1"/> - <line number="440" hits="1"/> - <line number="441" hits="1"/> + <line number="436" hits="1"/> + <line number="437" hits="1"/> + <line number="439" hits="0"/> + <line number="440" hits="0"/> + <line number="442" hits="1"/> <line number="443" hits="1"/> <line number="444" hits="1"/> - <line number="445" hits="1"/> - <line number="447" hits="1"/> - <line number="449" hits="1"/> + <line number="446" hits="1"/> + <line number="448" hits="1"/> + <line number="452" hits="1"/> <line number="453" hits="1"/> - <line number="455" hits="1"/> - <line number="463" hits="1"/> - <line number="464" hits="1"/> + <line number="454" hits="1"/> + <line number="458" hits="1"/> + <line number="459" hits="1"/> <line number="465" hits="1"/> <line number="466" hits="1"/> - <line number="468" hits="1"/> + <line number="467" hits="1"/> <line number="469" hits="1"/> + <line number="470" hits="1"/> <line number="471" hits="1"/> - <line number="472" hits="1"/> - <line number="474" hits="1"/> + <line number="473" hits="1"/> + <line number="475" hits="1"/> <line number="479" hits="1"/> <line number="481" hits="1"/> - <line number="483" hits="1"/> - <line number="484" hits="0"/> - <line number="485" hits="0"/> - <line number="487" hits="1"/> <line number="489" hits="1"/> - <line number="493" hits="1"/> + <line number="490" hits="1"/> + <line number="491" hits="1"/> + 
<line number="492" hits="1"/> <line number="494" hits="1"/> <line number="495" hits="1"/> - <line number="496" hits="1"/> + <line number="497" hits="1"/> <line number="498" hits="1"/> - <line number="499" hits="1"/> - <line number="501" hits="1"/> + <line number="500" hits="1"/> <line number="505" hits="1"/> - <line number="506" hits="1"/> - <line number="508" hits="1"/> + <line number="507" hits="1"/> <line number="509" hits="1"/> - <line number="510" hits="1"/> - <line number="511" hits="1"/> - <line number="512" hits="1"/> - <line number="514" hits="1"/> + <line number="510" hits="0"/> + <line number="511" hits="0"/> + <line number="513" hits="1"/> <line number="515" hits="1"/> - <line number="518" hits="1"/> + <line number="519" hits="1"/> + <line number="520" hits="1"/> + <line number="521" hits="1"/> + <line number="522" hits="1"/> <line number="524" hits="1"/> <line number="525" hits="1"/> - <line number="526" hits="1"/> - <line number="528" hits="1"/> - <line number="529" hits="1"/> - <line number="530" hits="1"/> + <line number="527" hits="1"/> + <line number="531" hits="1"/> + <line number="532" hits="1"/> + <line number="534" hits="1"/> + <line number="535" hits="1"/> + <line number="536" hits="1"/> + <line number="537" hits="1"/> <line number="538" hits="1"/> - <line number="542" hits="1"/> + <line number="540" hits="1"/> + <line number="541" hits="1"/> + <line number="544" hits="1"/> + <line number="550" hits="1"/> + <line number="551" hits="1"/> + <line number="552" hits="1"/> + <line number="554" hits="1"/> + <line number="555" hits="1"/> + <line number="556" hits="1"/> + <line number="564" hits="1"/> + <line number="568" hits="1"/> </lines> </class> <class name="linear.py" filename="linear.py" complexity="0" line-rate="1" branch-rate="0"> @@ -671,7 +698,7 @@ </class> </classes> </package> - <package name="encoders" line-rate="0.9574" branch-rate="0" complexity="0"> + <package name="encoders" line-rate="0.9485" branch-rate="0" complexity="0"> 
<classes> <class name="__init__.py" filename="encoders/__init__.py" complexity="0" line-rate="1" branch-rate="0"> <methods/> @@ -878,7 +905,7 @@ <line number="49" hits="1"/> </lines> </class> - <class name="fastembed.py" filename="encoders/fastembed.py" complexity="0" line-rate="0.8667" branch-rate="0"> + <class name="fastembed.py" filename="encoders/fastembed.py" complexity="0" line-rate="0.7" branch-rate="0"> <methods/> <lines> <line number="1" hits="1"/> @@ -903,12 +930,12 @@ <line number="33" hits="1"/> <line number="40" hits="1"/> <line number="42" hits="1"/> - <line number="43" hits="1"/> + <line number="43" hits="0"/> <line number="45" hits="1"/> - <line number="46" hits="1"/> - <line number="47" hits="1"/> - <line number="48" hits="1"/> - <line number="49" hits="1"/> + <line number="46" hits="0"/> + <line number="47" hits="0"/> + <line number="48" hits="0"/> + <line number="49" hits="0"/> <line number="50" hits="0"/> <line number="51" hits="0"/> </lines> diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 9791786f5aeb622477d76d222c9c9c272e4e57bb..5f223384b9a5f2db983fd10074e353a5078424c3 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -25,6 +25,7 @@ class HybridRouteLayer: routes: List[Route] = [], alpha: float = 0.3, top_k: int = 5, + aggregation: str = "sum", ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold @@ -39,6 +40,12 @@ class HybridRouteLayer: self.top_k = top_k if self.top_k < 1: raise ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") + self.aggregation = aggregation + if self.aggregation not in ["sum", "mean", "max"]: + raise ValueError( + f"Unsupported aggregation method chosen: {aggregation}. Choose either 'sum', 'mean', or 'max'." 
+ ) + self.aggregation_method = self._set_aggregation_method(self.aggregation) self.routes = routes if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( self.sparse_encoder, "fit" @@ -165,6 +172,18 @@ class HybridRouteLayer: sparse = np.array(sparse) * (1 - self.alpha) return dense, sparse + def _set_aggregation_method(self, aggregation: str = "sum"): + if aggregation == "sum": + return lambda x: sum(x) + elif aggregation == "mean": + return lambda x: np.mean(x) + elif aggregation == "max": + return lambda x: max(x) + else: + raise ValueError( + f"Unsupported aggregation method chosen: {aggregation}. Choose either 'sum', 'mean', or 'max'." + ) + def _semantic_classify(self, query_results: List[Dict]) -> Tuple[str, List[float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: @@ -176,7 +195,10 @@ class HybridRouteLayer: scores_by_class[route] = [score] # Calculate total score for each class - total_scores = {route: sum(scores) for route, scores in scores_by_class.items()} + total_scores = { + route: self.aggregation_method(scores) + for route, scores in scores_by_class.items() + } top_class = max(total_scores, key=lambda x: total_scores[x], default=None) # Return the top class and its associated scores diff --git a/semantic_router/layer.py b/semantic_router/layer.py index d0d3e33a346ee6c0f98ccaade0feebe660c2a6b7..221de2bef3f02e5992750fd98d63fd1e02e94659 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -182,6 +182,8 @@ class RouteLayer: llm: Optional[BaseLLM] = None, routes: Optional[List[Route]] = None, index: Optional[BaseIndex] = None, # type: ignore + top_k: int = 5, + aggregation: str = "sum", ): logger.info("local") self.index: BaseIndex = index if index is not None else LocalIndex() @@ -196,6 +198,16 @@ class RouteLayer: self.llm = llm self.routes: list[Route] = routes if routes is not None else [] self.score_threshold = self.encoder.score_threshold + self.top_k = top_k + if self.top_k < 1: + raise 
ValueError(f"top_k needs to be >= 1, but was: {self.top_k}.") + self.aggregation = aggregation + if self.aggregation not in ["sum", "mean", "max"]: + raise ValueError( + f"Unsupported aggregation method chosen: {aggregation}. Choose either 'sum', 'mean', or 'max'." + ) + self.aggregation_method = self._set_aggregation_method(self.aggregation) + # set route score thresholds if not already set for route in self.routes: if route.score_threshold is None: @@ -266,7 +278,7 @@ class RouteLayer: Returns a tuple of the route (if any) and the scores of the top class. """ # get relevant results (scores and routes) - results = self._retrieve(xq=np.array(vector)) + results = self._retrieve(xq=np.array(vector), top_k=self.top_k) # decide most relevant routes top_class, top_class_scores = self._semantic_classify(results) # TODO do we need this check? @@ -391,6 +403,18 @@ class RouteLayer: scores, routes = self.index.query(vector=xq, top_k=top_k) return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)] + def _set_aggregation_method(self, aggregation: str = "sum"): + if aggregation == "sum": + return lambda x: sum(x) + elif aggregation == "mean": + return lambda x: np.mean(x) + elif aggregation == "max": + return lambda x: max(x) + else: + raise ValueError( + f"Unsupported aggregation method chosen: {aggregation}. Choose either 'sum', 'mean', or 'max'." 
+ ) + def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]: scores_by_class: Dict[str, List[float]] = {} for result in query_results: @@ -402,7 +426,10 @@ class RouteLayer: scores_by_class[route] = [score] # Calculate total score for each class - total_scores = {route: sum(scores) for route, scores in scores_by_class.items()} + total_scores = { + route: self.aggregation_method(scores) + for route, scores in scores_by_class.items() + } top_class = max(total_scores, key=lambda x: total_scores[x], default=None) # Return the top class and its associated scores diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index d489650903241a81039bfbbcc58d8bdd86ad5842..bf0c2ad2d91aac2b7e2daeae8a595cbbf0a8ae39 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -193,5 +193,48 @@ class TestHybridRouteLayer: assert hybrid_route_layer.sparse_index is not None assert len(hybrid_route_layer.sparse_index) == len(all_utterances) + def test_setting_aggregation_methods(self, openai_encoder, routes): + for agg in ["sum", "mean", "max"]: + route_layer = HybridRouteLayer( + encoder=openai_encoder, + sparse_encoder=sparse_encoder, + routes=routes, + aggregation=agg, + ) + assert route_layer.aggregation == agg + + def test_semantic_classify_multiple_routes_with_different_aggregation( + self, openai_encoder, routes + ): + route_scores = [ + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 2", "score": 0.4}, + {"route": "Route 2", "score": 0.6}, + {"route": "Route 2", "score": 0.8}, + {"route": "Route 3", "score": 0.1}, + {"route": "Route 3", "score": 1.0}, + ] + for agg in ["sum", "mean", "max"]: + route_layer = HybridRouteLayer( + encoder=openai_encoder, + sparse_encoder=sparse_encoder, + routes=routes, + aggregation=agg, + ) + classification, score = 
route_layer._semantic_classify(route_scores) + + if agg == "sum": + assert classification == "Route 1" + assert score == [0.5, 0.5, 0.5, 0.5] + elif agg == "mean": + assert classification == "Route 2" + assert score == [0.4, 0.6, 0.8] + elif agg == "max": + assert classification == "Route 3" + assert score == [0.1, 1.0] + # Add more tests for edge cases and error handling as needed. diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 415150a597ee396d593c88b651658bd28410fc80..4a55777b66658a4635add189c7402a12fa947932 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -120,9 +120,10 @@ def test_data(): class TestRouteLayer: def test_initialization(self, openai_encoder, routes): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, top_k=10) assert openai_encoder.score_threshold == 0.82 assert route_layer.score_threshold == 0.82 + assert route_layer.top_k == 10 assert len(route_layer.index) if route_layer.index is not None else 0 == 5 assert ( len(set(route_layer._get_route_names())) @@ -522,3 +523,44 @@ class TestLayerConfig: layer_config = LayerConfig(routes=[route]) layer_config.remove("test") assert layer_config.routes == [] + + def test_setting_aggregation_methods(self, openai_encoder, routes): + for agg in ["sum", "mean", "max"]: + route_layer = RouteLayer( + encoder=openai_encoder, + routes=routes, + aggregation=agg, + ) + assert route_layer.aggregation == agg + + def test_semantic_classify_multiple_routes_with_different_aggregation( + self, openai_encoder, routes + ): + route_scores = [ + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 1", "score": 0.5}, + {"route": "Route 2", "score": 0.4}, + {"route": "Route 2", "score": 0.6}, + {"route": "Route 2", "score": 0.8}, + {"route": "Route 3", "score": 0.1}, + {"route": "Route 3", "score": 1.0}, + ] + for agg in 
["sum", "mean", "max"]: + route_layer = RouteLayer( + encoder=openai_encoder, + routes=routes, + aggregation=agg, + ) + classification, score = route_layer._semantic_classify(route_scores) + + if agg == "sum": + assert classification == "Route 1" + assert score == [0.5, 0.5, 0.5, 0.5] + elif agg == "mean": + assert classification == "Route 2" + assert score == [0.4, 0.6, 0.8] + elif agg == "max": + assert classification == "Route 3" + assert score == [0.1, 1.0]