11
2- import { pipeline } from '@huggingface/transformers' ;
2+ import { pipeline , env } from '@huggingface/transformers' ;
33import { DISEASES } from '@/data/diseases' ;
44
55let classifier : any = null ;
6+ let deviceType : 'cpu' | 'webgpu' = 'cpu' ;
7+
8+ // Check for WebGPU support
9+ const checkWebGPUSupport = async ( ) : Promise < boolean > => {
10+ if ( ! navigator . gpu ) {
11+ console . log ( 'WebGPU is not supported in this browser' ) ;
12+ return false ;
13+ }
14+
15+ try {
16+ const adapter = await navigator . gpu . requestAdapter ( ) ;
17+ if ( ! adapter ) {
18+ console . log ( 'No WebGPU adapter found' ) ;
19+ return false ;
20+ }
21+ console . log ( 'WebGPU is supported' ) ;
22+ return true ;
23+ } catch ( error ) {
24+ console . log ( 'Error checking WebGPU support:' , error ) ;
25+ return false ;
26+ }
27+ } ;
628
729export const initializeModel = async ( ) => {
830 if ( ! classifier ) {
9- classifier = await pipeline (
10- 'zero-shot-classification' ,
11- 'facebook/bart-large-mnli' ,
12- { device : 'cpu' }
13- ) ;
31+ try {
32+ // Check for WebGPU support
33+ const hasWebGPU = await checkWebGPUSupport ( ) ;
34+ deviceType = hasWebGPU ? 'webgpu' : 'cpu' ;
35+
36+ console . log ( `Initializing model on ${ deviceType } ` ) ;
37+
38+ // Configure transformers.js
39+ env . useBrowserCache = true ;
40+ env . allowLocalModels = false ;
41+
42+ classifier = await pipeline (
43+ 'zero-shot-classification' ,
44+ 'facebook/bart-large-mnli' ,
45+ { device : deviceType }
46+ ) ;
47+
48+ console . log ( 'Model initialized successfully' ) ;
49+ } catch ( error ) {
50+ console . error ( 'Error initializing model:' , error ) ;
51+ // Fallback to CPU if WebGPU initialization fails
52+ if ( deviceType === 'webgpu' ) {
53+ console . log ( 'Falling back to CPU' ) ;
54+ deviceType = 'cpu' ;
55+ classifier = await pipeline (
56+ 'zero-shot-classification' ,
57+ 'facebook/bart-large-mnli' ,
58+ { device : 'cpu' }
59+ ) ;
60+ } else {
61+ throw error ;
62+ }
63+ }
1464 }
1565 return classifier ;
1666} ;
1767
1868export const predictDisease = async ( symptoms : string [ ] ) => {
19- // If no symptoms provided, return early
2069 if ( ! symptoms . length ) {
2170 throw new Error ( "Please select at least one symptom for analysis" ) ;
2271 }
2372
73+ console . log ( `Running prediction on ${ deviceType } ` ) ;
2474 const model = await initializeModel ( ) ;
2575
26- // Convert symptoms array to a text description
2776 const symptomText = symptoms . join ( ', ' ) ;
77+ console . log ( 'Processing symptoms:' , symptomText ) ;
2878
29- // Get disease names as candidate labels
3079 const candidateLabels = DISEASES . map ( disease => disease . name ) ;
3180
3281 try {
82+ console . log ( 'Starting prediction...' ) ;
3383 const result = await model ( symptomText , candidateLabels ) ;
84+ console . log ( 'Raw prediction results:' , result ) ;
3485
35- // Filter predictions with score > 20%
3686 const predictions = result . labels
3787 . map ( ( label : string , index : number ) => ( {
3888 name : label ,
@@ -42,14 +92,16 @@ export const predictDisease = async (symptoms: string[]) => {
4292 } ) )
4393 . filter ( pred => pred . probability > 20 )
4494 . sort ( ( a , b ) => b . probability - a . probability )
45- . slice ( 0 , 3 ) ; // Top 3 predictions
95+ . slice ( 0 , 3 ) ;
4696
4797 if ( predictions . length === 0 ) {
4898 throw new Error ( "No strong matches found for the provided symptoms" ) ;
4999 }
50100
101+ console . log ( 'Final predictions:' , predictions ) ;
51102 return predictions ;
52103 } catch ( error ) {
104+ console . error ( 'Prediction error:' , error ) ;
53105 if ( error instanceof Error ) {
54106 throw error ;
55107 }
0 commit comments