機器學習模型的挑戰
機器學習模型在預測某些個體時可能會失敗,特別是當這些個體在訓練數據中代表性不足時。
例如,一個預測慢性病患者最佳治療方案的模型,可能是用主要包含男性患者的數據集訓練的。當這個模型在醫院使用時,可能會對女性患者做出錯誤的預測。
改善預測結果的方法
為了改善結果,工程師可以嘗試通過刪除數據點來平衡訓練數據集,直到所有子群體的代表性相等。雖然數據集平衡是有希望的,但這通常需要刪除大量數據,這會影響模型的整體表現。
麻省理工學院 (MIT) 的研究人員開發了一種新技術,可以識別並刪除訓練數據集中對少數群體失敗影響最大的特定數據點。通過刪除比其他方法少得多的數據點,這項技術在提升模型對於代表性不足群體的表現的同時,保持了模型的整體準確性。
發現隱藏的偏見
此外,這項技術還可以識別缺乏標籤的訓練數據集中的隱藏偏見來源。對於許多應用來說,未標記的數據比標記的數據更為普遍。
這種方法還可以與其他方法結合,改善在高風險情況下部署的機器學習模型的公平性。例如,未來它可能幫助確保代表性不足的患者不會因為偏見的人工智慧 (AI) 模型而被誤診。
研究團隊的見解
麻省理工學院的電機工程與計算機科學 (EECS) 研究生 Kimia Hamidieh 說:“許多其他試圖解決這個問題的算法假設每個數據點的重要性相同。但我們的研究顯示這個假設並不正確。我們可以找到那些導致偏見的特定數據點,刪除它們,並獲得更好的表現。”
她與共同作者 Saachi Jain 博士、EECS 研究生 Kristian Georgiev、斯坦福大學的 Andrew Ilyas、以及資深作者 Marzyeh Ghassemi 和 Aleksander Madry 共同撰寫了這篇論文。這項研究將在神經信息處理系統會議上發表。
刪除不良範例
通常,機器學習模型是使用來自互聯網多個來源的大型數據集進行訓練的。這些數據集太大,無法手動仔細篩選,因此可能包含會影響模型表現的不良範例。
科學家們也知道某些數據點對模型在某些任務上的表現影響比其他數據點更大。
麻省理工學院的研究人員將這兩個想法結合起來,提出了一種識別和刪除這些問題數據點的方法。他們試圖解決一個稱為最差群體錯誤的問題,這種情況發生在模型在訓練數據集中的少數群體表現不佳時。
新技術的優勢
研究人員的新技術基於他們之前的工作,該工作引入了一種名為 TRAK 的方法,用於識別特定模型輸出的最重要訓練範例。
對於這項新技術,他們分析模型對少數群體的錯誤預測,並使用 TRAK 來識別哪些訓練範例對這些錯誤預測貢獻最大。
Ilyas 解釋說:“通過以正確的方式聚合這些不良測試預測的信息,我們能夠找到影響整體最差群體準確性的訓練部分。”
然後,他們刪除這些特定樣本,並在剩餘數據上重新訓練模型。
由於擁有更多數據通常會提高整體表現,僅刪除導致最差群體失敗的樣本可以保持模型的整體準確性,同時提升其在少數群體上的表現。
更易於使用的方法
在三個機器學習數據集中,他們的方法表現超過了多種技術。在一個例子中,它在刪除約 20,000 個訓練樣本的情況下,提高了最差群體的準確性,這比傳統的數據平衡方法少得多。這項技術的準確性也高於需要改變模型內部運作的方法。
因為麻省理工學院的方法涉及改變數據集,所以對於實踐者來說更容易使用,並且可以應用於多種類型的模型。
當偏見未知時,這項技術也可以被利用,因為訓練數據集中的子群體沒有標籤。通過識別對模型學習的特徵貢獻最大的數據點,他們可以理解模型用來做預測的變數。
Hamidieh 說:“這是一個任何人在訓練機器學習模型時都可以使用的工具。他們可以查看這些數據點,看看它們是否與他們想要教給模型的能力相符。”
使用這項技術來檢測未知的子群體偏見需要對要查找的群體有直覺,因此研究人員希望通過未來的人類研究來驗證並更全面地探索這一點。
他們還希望提高技術的性能和可靠性,並確保這種方法對於未來可能在現實環境中部署的實踐者來說是可及且易於使用的。
Ilyas 說:“當你擁有能讓你批判性地查看數據的工具,並找出哪些數據點會導致偏見或其他不良行為時,這為建立更公平和更可靠的模型邁出了第一步。”
這項工作部分由美國國家科學基金會和美國國防高級研究計畫局資助。
新聞來源
本文由 AI 台灣 使用 AI 編撰,內容僅供參考,請自行進行事實查核。加入 AI TAIWAN Google News,隨時掌握最新 AI 資訊!