Skip to content

Commit fc7e3f7

Browse files
feat: Implement device detection and model optimization
This commit introduces device detection for webgpu and wasm support, optimizes model loading and prediction, and implements feature updates to enhance the predicted output.
1 parent 589d15e commit fc7e3f7

File tree

2 files changed

+114
-33
lines changed

2 files changed

+114
-33
lines changed

src/components/symptoms/SymptomChecker.tsx

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useState } from "react";
1+
import { useState, useEffect } from "react";
22
import { Button } from "@/components/ui/button";
33
import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from "@/components/ui/card";
44
import { useToast } from "@/components/ui/use-toast";
@@ -19,6 +19,25 @@ const SymptomChecker = () => {
1919
const [isLoading, setIsLoading] = useState(false);
2020
const [results, setResults] = useState<any | null>(null);
2121
const [showAllSymptoms, setShowAllSymptoms] = useState(false);
22+
const [isModelLoading, setIsModelLoading] = useState(true);
23+
24+
useEffect(() => {
25+
const initModel = async () => {
26+
try {
27+
await initializeModel();
28+
} catch (error) {
29+
toast({
30+
title: "Model Initialization Error",
31+
description: "Error loading the analysis model. Please try again.",
32+
variant: "destructive",
33+
});
34+
} finally {
35+
setIsModelLoading(false);
36+
}
37+
};
38+
39+
initModel();
40+
}, []);
2241

2342
const filteredSymptoms = searchTerm
2443
? searchSymptoms(searchTerm)
@@ -135,38 +154,48 @@ const SymptomChecker = () => {
135154
<CardHeader>
136155
<CardTitle className="text-2xl">Symptom Checker</CardTitle>
137156
<CardDescription>
138-
Add your symptoms and provide additional information for a more accurate analysis.
157+
{isModelLoading
158+
? "Initializing analysis model..."
159+
: "Add your symptoms and provide additional information for a more accurate analysis."}
139160
</CardDescription>
140161
</CardHeader>
141162

142163
<CardContent className="space-y-6">
143-
<SymptomSearchInput
144-
searchTerm={searchTerm}
145-
setSearchTerm={setSearchTerm}
146-
filteredSymptoms={filteredSymptoms}
147-
onAddSymptom={handleAddSymptom}
148-
onVoiceInput={handleVoiceInput}
149-
showAllSymptoms={showAllSymptoms}
150-
setShowAllSymptoms={setShowAllSymptoms}
151-
/>
152-
<SelectedSymptomsList
153-
selectedSymptoms={selectedSymptoms}
154-
onRemoveSymptom={handleRemoveSymptom}
155-
/>
156-
<AdditionalInfoTextarea
157-
value={additionalInfo}
158-
onChange={setAdditionalInfo}
159-
/>
164+
{isModelLoading ? (
165+
<div className="flex items-center justify-center py-8">
166+
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-revon-primary"></div>
167+
</div>
168+
) : (
169+
<>
170+
<SymptomSearchInput
171+
searchTerm={searchTerm}
172+
setSearchTerm={setSearchTerm}
173+
filteredSymptoms={filteredSymptoms}
174+
onAddSymptom={handleAddSymptom}
175+
onVoiceInput={handleVoiceInput}
176+
showAllSymptoms={showAllSymptoms}
177+
setShowAllSymptoms={setShowAllSymptoms}
178+
/>
179+
<SelectedSymptomsList
180+
selectedSymptoms={selectedSymptoms}
181+
onRemoveSymptom={handleRemoveSymptom}
182+
/>
183+
<AdditionalInfoTextarea
184+
value={additionalInfo}
185+
onChange={setAdditionalInfo}
186+
/>
187+
</>
188+
)}
160189
</CardContent>
161190

162191
<CardFooter>
163192
<Button
164193
className="w-full gradient-btn"
165194
onClick={handleAnalyze}
166-
disabled={isLoading}
195+
disabled={isLoading || isModelLoading}
167196
>
168-
{isLoading ? "Analyzing..." : "Analyze Symptoms"}
169-
{!isLoading && <ArrowRight className="ml-2 h-4 w-4" />}
197+
{isModelLoading ? "Initializing..." : isLoading ? "Analyzing..." : "Analyze Symptoms"}
198+
{!isLoading && !isModelLoading && <ArrowRight className="ml-2 h-4 w-4" />}
170199
</Button>
171200
</CardFooter>
172201
</Card>

src/utils/diseasePredictor.ts

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,88 @@
11

2-
import { pipeline } from '@huggingface/transformers';
2+
import { pipeline, env } from '@huggingface/transformers';
33
import { DISEASES } from '@/data/diseases';
44

55
let classifier: any = null;
6+
let deviceType: 'cpu' | 'webgpu' = 'cpu';
7+
8+
// Check for WebGPU support
9+
const checkWebGPUSupport = async (): Promise<boolean> => {
10+
if (!navigator.gpu) {
11+
console.log('WebGPU is not supported in this browser');
12+
return false;
13+
}
14+
15+
try {
16+
const adapter = await navigator.gpu.requestAdapter();
17+
if (!adapter) {
18+
console.log('No WebGPU adapter found');
19+
return false;
20+
}
21+
console.log('WebGPU is supported');
22+
return true;
23+
} catch (error) {
24+
console.log('Error checking WebGPU support:', error);
25+
return false;
26+
}
27+
};
628

729
export const initializeModel = async () => {
830
if (!classifier) {
9-
classifier = await pipeline(
10-
'zero-shot-classification',
11-
'facebook/bart-large-mnli',
12-
{ device: 'cpu' }
13-
);
31+
try {
32+
// Check for WebGPU support
33+
const hasWebGPU = await checkWebGPUSupport();
34+
deviceType = hasWebGPU ? 'webgpu' : 'cpu';
35+
36+
console.log(`Initializing model on ${deviceType}`);
37+
38+
// Configure transformers.js
39+
env.useBrowserCache = true;
40+
env.allowLocalModels = false;
41+
42+
classifier = await pipeline(
43+
'zero-shot-classification',
44+
'facebook/bart-large-mnli',
45+
{ device: deviceType }
46+
);
47+
48+
console.log('Model initialized successfully');
49+
} catch (error) {
50+
console.error('Error initializing model:', error);
51+
// Fallback to CPU if WebGPU initialization fails
52+
if (deviceType === 'webgpu') {
53+
console.log('Falling back to CPU');
54+
deviceType = 'cpu';
55+
classifier = await pipeline(
56+
'zero-shot-classification',
57+
'facebook/bart-large-mnli',
58+
{ device: 'cpu' }
59+
);
60+
} else {
61+
throw error;
62+
}
63+
}
1464
}
1565
return classifier;
1666
};
1767

1868
export const predictDisease = async (symptoms: string[]) => {
19-
// If no symptoms provided, return early
2069
if (!symptoms.length) {
2170
throw new Error("Please select at least one symptom for analysis");
2271
}
2372

73+
console.log(`Running prediction on ${deviceType}`);
2474
const model = await initializeModel();
2575

26-
// Convert symptoms array to a text description
2776
const symptomText = symptoms.join(', ');
77+
console.log('Processing symptoms:', symptomText);
2878

29-
// Get disease names as candidate labels
3079
const candidateLabels = DISEASES.map(disease => disease.name);
3180

3281
try {
82+
console.log('Starting prediction...');
3383
const result = await model(symptomText, candidateLabels);
84+
console.log('Raw prediction results:', result);
3485

35-
// Filter predictions with score > 20%
3686
const predictions = result.labels
3787
.map((label: string, index: number) => ({
3888
name: label,
@@ -42,14 +92,16 @@ export const predictDisease = async (symptoms: string[]) => {
4292
}))
4393
.filter(pred => pred.probability > 20)
4494
.sort((a, b) => b.probability - a.probability)
45-
.slice(0, 3); // Top 3 predictions
95+
.slice(0, 3);
4696

4797
if (predictions.length === 0) {
4898
throw new Error("No strong matches found for the provided symptoms");
4999
}
50100

101+
console.log('Final predictions:', predictions);
51102
return predictions;
52103
} catch (error) {
104+
console.error('Prediction error:', error);
53105
if (error instanceof Error) {
54106
throw error;
55107
}

0 commit comments

Comments
 (0)