# AI Model Integration Guide

## Overview

CulicidaeLab integrates PyTorch Lite models for on-device mosquito species classification. This guide covers the complete AI pipeline, from model loading to inference, including architecture details, performance considerations, and troubleshooting.

## Architecture Overview

### AI Pipeline Components
```mermaid
graph TB
    subgraph "User Interface"
        UI[Classification Screen]
        CAM[Camera/Gallery]
    end

    subgraph "Business Logic"
        VM[ClassificationViewModel]
        REPO[ClassificationRepository]
    end

    subgraph "AI Services"
        CS[ClassificationService]
        PW[PytorchWrapper]
    end

    subgraph "PyTorch Integration"
        PLM[PytorchLite]
        CM[ClassificationModel]
        API[ModelApi - Native]
    end

    subgraph "Model Assets"
        MODEL[mosquito_classifier.pt]
        LABELS[mosquito_species.txt]
    end

    UI --> VM
    CAM --> VM
    VM --> REPO
    REPO --> CS
    CS --> PW
    PW --> PLM
    PLM --> CM
    CM --> API
    PLM -.-> MODEL
    PLM -.-> LABELS
```
### Data Flow

```mermaid
sequenceDiagram
    participant User
    participant UI as ClassificationScreen
    participant VM as ClassificationViewModel
    participant CS as ClassificationService
    participant PW as PytorchWrapper
    participant Model as ClassificationModel
    participant Native as Native PyTorch

    User->>UI: Capture/Select Image
    UI->>VM: classifyImage(imageFile)
    VM->>CS: classifyImage(imageFile)

    Note over CS: Check if model loaded
    alt Model not loaded
        CS->>PW: loadClassificationModel()
        PW->>Model: Create model instance
        Model->>Native: Load .pt file
        Native-->>Model: Model loaded
        Model-->>PW: ClassificationModel
        PW-->>CS: Model ready
    end

    CS->>Model: getImagePredictionResult(imageBytes)
    Model->>Native: Process image + inference
    Native-->>Model: Raw predictions
    Model->>Model: Apply softmax
    Model-->>CS: {label, probability}
    CS-->>VM: Classification result
    VM-->>UI: Update UI state
    UI-->>User: Display results
```
## Model Integration Components

### 1. PytorchWrapper

The `PytorchWrapper` class provides a testable interface around PyTorch Lite operations:
```dart
class PytorchWrapper {
  /// Loads a classification model from assets
  Future<ClassificationModel> loadClassificationModel(
    String pathImageModel,
    int imageWidth,
    int imageHeight, {
    String? labelPath,
  }) {
    return PytorchLite.loadClassificationModel(
      pathImageModel,
      imageWidth,
      imageHeight,
      labelPath: labelPath,
    );
  }
}
```
**Key Features:**

- **Testability:** Enables mocking for unit tests
- **Dependency Injection:** Supports the service locator pattern
- **Future Extensibility:** Provides hooks for logging and caching
### 2. ClassificationService

The main service orchestrating AI operations:
```dart
class ClassificationService {
  final PytorchWrapper _pytorchWrapper;
  ClassificationModel? _model;

  /// Load the mosquito classification model
  Future<void> loadModel() async {
    const String pathImageModel = "assets/models/mosquito_classifier.pt";
    try {
      _model = await _pytorchWrapper.loadClassificationModel(
        pathImageModel,
        224,
        224,
        labelPath: "assets/labels/mosquito_species.txt",
      );
    } on PlatformException {
      throw Exception("Model loading failed - only supported for Android/iOS");
    }
  }

  /// Classify a mosquito image
  Future<Map<String, dynamic>> classifyImage(File imageFile) async {
    if (_model == null) {
      throw Exception("Model not loaded - call loadModel() first");
    }

    final imageBytes = await imageFile.readAsBytes();
    final result = await _model!.getImagePredictionResult(imageBytes);

    return {
      'scientificName': result['label'].trim(),
      'confidence': result['probability'],
    };
  }
}
```
**Responsibilities:**

- Model lifecycle management
- Image preprocessing coordination
- Error handling and validation
- Performance monitoring
### 3. ClassificationModel

The PyTorch Lite model wrapper providing inference capabilities:
```dart
class ClassificationModel {
  final int _index;
  final List<String> labels;

  /// Get prediction with confidence score
  Future<Map<String, dynamic>> getImagePredictionResult(
    Uint8List imageAsBytes, {
    List<double> mean = TORCHVISION_NORM_MEAN_RGB,
    List<double> std = TORCHVISION_NORM_STD_RGB,
  }) async {
    final List<double?> prediction = await ModelApi().getImagePredictionList(
        _index, imageAsBytes, null, null, null, mean, std);

    // Find the index of the highest raw logit. Softmax is monotonic,
    // so this is also the index of the highest probability.
    int maxScoreIndex = 0;
    for (int i = 1; i < prediction.length; i++) {
      if (prediction[i]! > prediction[maxScoreIndex]!) {
        maxScoreIndex = i;
      }
    }

    // Apply softmax to turn the logits into probabilities
    double sumExp = 0.0;
    for (var element in prediction) {
      sumExp += math.exp(element!);
    }
    final probabilities =
        prediction.map((element) => math.exp(element!) / sumExp).toList();

    return {
      "label": labels[maxScoreIndex],
      "probability": probabilities[maxScoreIndex],
    };
  }
}
```
## Model Configuration

### Model Specifications

#### Current Model Details

- **File:** `assets/models/mosquito_classifier.pt`
- **Input Size:** 224x224 pixels
- **Color Space:** RGB
- **Normalization:** ImageNet standard (mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225])
- **Output:** Softmax probabilities over the mosquito species classes
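To make the normalization concrete, here is a standalone sketch of how a single RGB pixel is transformed before inference. The plugin performs this step natively; the function name here is illustrative:

```dart
// ImageNet normalization constants, as listed in the model specification
const mean = [0.485, 0.456, 0.406];
const std = [0.229, 0.224, 0.225];

/// Normalizes one RGB pixel (0-255 channel values): scale to [0, 1],
/// subtract the per-channel mean, divide by the per-channel std.
List<double> normalizePixel(List<int> rgb) {
  return List.generate(3, (c) => (rgb[c] / 255.0 - mean[c]) / std[c]);
}

void main() {
  // A mid-gray pixel lands near zero on every channel after normalization.
  print(normalizePixel([128, 128, 128]));
}
```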
### Label Configuration

- **File:** `assets/labels/mosquito_species.txt`
- **Format:** One species name per line, in the same order as the model's output classes
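For illustration, a labels file might look like the following (the species shown are examples; the bundled file defines the actual classes and their order):

```text
Aedes aegypti
Aedes albopictus
Anopheles gambiae
Culex pipiens
```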
### Asset Organization

```text
assets/
├── models/
│   └── mosquito_classifier.pt    # PyTorch Lite model file
└── labels/
    └── mosquito_species.txt      # Species label mappings
```
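These assets must also be registered in `pubspec.yaml` so Flutter bundles them with the app. A minimal declaration might look like:

```yaml
flutter:
  assets:
    - assets/models/mosquito_classifier.pt
    - assets/labels/mosquito_species.txt
```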
## Image Processing Pipeline

### 1. Image Input Handling

```dart
// From camera or gallery
final imageFile = File('/path/to/image.jpg');
final imageBytes = await imageFile.readAsBytes();
```
**Supported Formats:**

- JPEG (`.jpg`, `.jpeg`)
- PNG (`.png`)
- BMP (`.bmp`)
- WebP (`.webp`)
### 2. Preprocessing

The PyTorch Lite plugin handles:

- **Resizing:** Automatic resize to 224x224 pixels
- **Normalization:** ImageNet mean/std normalization
- **Color Space:** RGB conversion if needed
- **Tensor Conversion:** Conversion to the model's input tensor format
### 3. Model Inference

```dart
// Native inference through ModelApi
final List<double?> rawPrediction = await ModelApi()
    .getImagePredictionList(modelIndex, imageBytes, ...);
```
### 4. Post-processing

```dart
// Apply softmax to get probabilities
double sumExp = 0.0;
for (var logit in rawPrediction) {
  sumExp += math.exp(logit!);
}
final probabilities = rawPrediction
    .map((logit) => math.exp(logit!) / sumExp)
    .toList();

// Get the top prediction
final int maxIndex = probabilities.indexOf(probabilities.reduce(math.max));
final String species = labels[maxIndex];
final double confidence = probabilities[maxIndex];
```
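One caveat worth knowing: the plain softmax above can overflow `math.exp` when logits are large. A common safeguard, sketched here as pure Dart, subtracts the maximum logit before exponentiating; the result is mathematically identical:

```dart
import 'dart:math' as math;

/// Numerically stable softmax: shifting every logit by the maximum
/// leaves the probabilities unchanged but keeps exp() from overflowing.
List<double> stableSoftmax(List<double> logits) {
  final maxLogit = logits.reduce(math.max);
  final exps = logits.map((l) => math.exp(l - maxLogit)).toList();
  final sum = exps.reduce((a, b) => a + b);
  return exps.map((e) => e / sum).toList();
}

void main() {
  // Works even for logits that would overflow a naive implementation.
  print(stableSoftmax([1000.0, 1001.0, 1002.0]));
}
```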
## Performance Optimization

### Model Loading Optimization

```dart
class ClassificationService {
  static ClassificationService? _instance;
  ClassificationModel? _model;

  // Singleton pattern for model reuse
  static ClassificationService get instance {
    _instance ??= ClassificationService(pytorchWrapper: PytorchWrapper());
    return _instance!;
  }

  // Load the model once and cache it
  Future<void> ensureModelLoaded() async {
    if (_model == null) {
      await loadModel();
    }
  }
}
```
### Memory Management

```dart
// Release large image buffers promptly
Future<Map<String, dynamic>> classifyImage(File imageFile) async {
  Uint8List? imageBytes;
  try {
    imageBytes = await imageFile.readAsBytes();
    final result = await _model!.getImagePredictionResult(imageBytes);
    return result;
  } finally {
    imageBytes = null; // Drop the reference so the buffer can be collected
  }
}
```
### Inference Optimization
- Batch Processing: Use batch methods for multiple images
- Image Size: Optimize input image size before processing
- Model Caching: Keep model loaded in memory between predictions
- Background Processing: Run inference on background isolates for heavy workloads
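To illustrate the background-processing point, CPU-heavy pre-processing (decoding and resizing) can be moved off the main isolate with `Isolate.run` (Dart 2.19+). This is a sketch with a hypothetical `preprocess` step; the native inference call itself still goes through the plugin's platform channel and is not covered here:

```dart
import 'dart:isolate';
import 'dart:typed_data';

// Hypothetical heavy preprocessing step
// (e.g. decode + resize to 224x224 with package:image).
Uint8List preprocess(Uint8List rawBytes) {
  // ... decode, resize, re-encode ...
  return rawBytes;
}

Future<Uint8List> preprocessInBackground(Uint8List rawBytes) {
  // Isolate.run spawns a short-lived isolate, runs the closure there,
  // and returns its result without blocking the UI isolate.
  return Isolate.run(() => preprocess(rawBytes));
}
```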
## Error Handling

### Common Error Scenarios

#### 1. Platform Support
```dart
try {
  await classificationService.loadModel();
} on PlatformException {
  // Handle unsupported platform (Web, Desktop)
  throw ClassificationException(
    'AI classification not supported on this platform',
    code: 'PLATFORM_UNSUPPORTED',
  );
}
```
#### 2. Model Loading Failures
```dart
try {
  _model = await _pytorchWrapper.loadClassificationModel(...);
} catch (e) {
  if (e.toString().contains('file not found')) {
    throw ClassificationException(
      'Model file not found in assets',
      code: 'MODEL_NOT_FOUND',
    );
  } else if (e.toString().contains('memory')) {
    throw ClassificationException(
      'Insufficient memory to load model',
      code: 'INSUFFICIENT_MEMORY',
    );
  }
  rethrow;
}
```
#### 3. Inference Errors
```dart
Future<Map<String, dynamic>> classifyImage(File imageFile) async {
  try {
    if (!imageFile.existsSync()) {
      throw ClassificationException('Image file not found');
    }

    final imageBytes = await imageFile.readAsBytes();
    if (imageBytes.isEmpty) {
      throw ClassificationException('Empty image file');
    }

    return await _model!.getImagePredictionResult(imageBytes);
  } catch (e) {
    if (e is ClassificationException) rethrow;
    throw ClassificationException('Classification failed: $e');
  }
}
```
### Error Recovery Strategies

```dart
class ClassificationService {
  static const int maxRetries = 3;

  Future<Map<String, dynamic>> classifyImageWithRetry(File imageFile) async {
    for (int attempt = 0; attempt <= maxRetries; attempt++) {
      try {
        return await classifyImage(imageFile);
      } catch (e) {
        if (attempt == maxRetries) rethrow;

        // Exponential backoff: 100ms, 200ms, 400ms, ...
        await Future.delayed(Duration(milliseconds: 100 * (1 << attempt)));

        // Reload the model if the failure looks model-related
        if (e.toString().contains('model')) {
          _model = null;
          await loadModel();
        }
      }
    }
    throw ClassificationException('Max retries exceeded');
  }
}
```
## Testing AI Integration

### Unit Testing with Mocks
```dart
class MockPytorchWrapper extends Mock implements PytorchWrapper {}

class MockClassificationModel extends Mock implements ClassificationModel {}

void main() {
  group('ClassificationService', () {
    late ClassificationService service;
    late MockPytorchWrapper mockWrapper;
    late MockClassificationModel mockModel;

    setUp(() {
      mockWrapper = MockPytorchWrapper();
      mockModel = MockClassificationModel();
      service = ClassificationService(pytorchWrapper: mockWrapper);
    });

    test('should load model successfully', () async {
      // Arrange (with mockito, named arguments need anyNamed)
      when(mockWrapper.loadClassificationModel(any, any, any,
              labelPath: anyNamed('labelPath')))
          .thenAnswer((_) async => mockModel);

      // Act
      await service.loadModel();

      // Assert
      expect(service.isModelLoaded, isTrue);
      verify(mockWrapper.loadClassificationModel(
        'assets/models/mosquito_classifier.pt',
        224,
        224,
        labelPath: 'assets/labels/mosquito_species.txt',
      )).called(1);
    });

    test('should classify image and return result', () async {
      // Arrange
      when(mockWrapper.loadClassificationModel(any, any, any,
              labelPath: anyNamed('labelPath')))
          .thenAnswer((_) async => mockModel);
      when(mockModel.getImagePredictionResult(any)).thenAnswer(
          (_) async => {'label': 'Aedes aegypti', 'probability': 0.85});
      await service.loadModel();
      final imageFile = File('test_image.jpg');

      // Act
      final result = await service.classifyImage(imageFile);

      // Assert
      expect(result['scientificName'], equals('Aedes aegypti'));
      expect(result['confidence'], equals(0.85));
    });
  });
}
```
### Integration Testing

```dart
void main() {
  group('AI Model Integration Tests', () {
    late ClassificationService service;

    setUpAll(() async {
      service = ClassificationService(pytorchWrapper: PytorchWrapper());
      await service.loadModel();
    });

    testWidgets('should classify real mosquito image', (tester) async {
      // Load a test image from assets
      final ByteData imageData =
          await rootBundle.load('test_assets/aedes_aegypti.jpg');
      final File tempFile =
          File('${Directory.systemTemp.path}/test_image.jpg');
      await tempFile.writeAsBytes(imageData.buffer.asUint8List());

      // Classify the image
      final result = await service.classifyImage(tempFile);

      // Verify the results
      expect(result['scientificName'], isA<String>());
      expect(result['confidence'], isA<double>());
      expect(result['confidence'], greaterThan(0.0));
      expect(result['confidence'], lessThanOrEqualTo(1.0));

      // Clean up
      await tempFile.delete();
    });
  });
}
```
## Model Deployment and Updates

### Model Versioning
```dart
class ModelConfig {
  static const String currentVersion = 'v1.2.0';
  static const String modelPath =
      'assets/models/mosquito_classifier_v1_2_0.pt';
  static const String labelsPath =
      'assets/labels/mosquito_species_v1_2_0.txt';

  // Model metadata
  static const Map<String, dynamic> modelInfo = {
    'version': currentVersion,
    'inputSize': [224, 224],
    'numClasses': 12,
    'accuracy': 0.94,
    'trainedOn': '2024-01-15',
  };
}
```
### Dynamic Model Loading
```dart
class ModelManager {
  static Future<String> getModelPath() async {
    // Check for an updated model in the documents directory
    // (getApplicationDocumentsDirectory comes from the path_provider package)
    final documentsDir = await getApplicationDocumentsDirectory();
    final updatedModelPath = '${documentsDir.path}/models/latest_model.pt';

    if (await File(updatedModelPath).exists()) {
      return updatedModelPath;
    }

    // Fall back to the bundled model
    return ModelConfig.modelPath;
  }

  static Future<void> downloadModelUpdate(String downloadUrl) async {
    // Implementation for downloading model updates
    // Include checksum verification and atomic replacement
  }
}
```
## Performance Monitoring

### Inference Metrics
```dart
class ClassificationMetrics {
  static final Stopwatch _inferenceTimer = Stopwatch();
  static final List<Duration> _inferenceTimes = [];

  static void startInference() {
    _inferenceTimer.reset();
    _inferenceTimer.start();
  }

  static void endInference() {
    _inferenceTimer.stop();
    _inferenceTimes.add(_inferenceTimer.elapsed);

    // Keep only the last 100 measurements
    if (_inferenceTimes.length > 100) {
      _inferenceTimes.removeAt(0);
    }
  }

  static Duration get averageInferenceTime {
    if (_inferenceTimes.isEmpty) return Duration.zero;
    final totalMs = _inferenceTimes
        .map((d) => d.inMilliseconds)
        .reduce((a, b) => a + b);
    return Duration(milliseconds: totalMs ~/ _inferenceTimes.length);
  }
}
```
### Memory Usage Tracking
```dart
class MemoryMonitor {
  static Future<void> logMemoryUsage(String operation) async {
    // Note: the fields exposed by device info plugins vary by platform
    // and plugin version; adapt this to the API you actually depend on.
    final info = await DeviceInfoPlugin().androidInfo;

    // Log memory usage for performance analysis
    print('Memory usage during $operation: ${info.totalMemory}');
  }
}
```
## Troubleshooting Guide

### Common Issues

#### 1. Model Loading Fails

**Symptoms:** Exception thrown during `loadModel()`

**Causes:**

- Missing model file in assets
- Corrupted model file
- Insufficient memory
- Unsupported platform

**Solutions:**

```dart
// Verify the asset exists and is non-empty
final ByteData modelData =
    await rootBundle.load('assets/models/mosquito_classifier.pt');
print('Model size: ${modelData.lengthInBytes} bytes');

// Check available memory (field availability depends on the
// device info plugin version in use)
final info = await DeviceInfoPlugin().androidInfo;
print('Available memory: ${info.totalMemory}');
```
#### 2. Poor Classification Accuracy

**Symptoms:** Low confidence scores or incorrect predictions

**Causes:**

- Poor image quality
- Incorrect preprocessing
- Model/data mismatch
**Solutions:**

```dart
// Validate image quality (uses package:image, imported as `img`)
Future<bool> validateImageQuality(File imageFile) async {
  final image = img.decodeImage(await imageFile.readAsBytes());
  if (image == null) return false;

  // Reject images below the model's 224x224 input resolution
  if (image.width < 224 || image.height < 224) return false;

  // Check for blur (simplified)
  // Implement a blur detection algorithm, e.g. variance of Laplacian
  return true;
}
```
#### 3. Slow Inference Performance

**Symptoms:** Long classification times

**Causes:**

- Large input images
- Memory pressure
- CPU throttling
**Solutions:**

```dart
// Downscale and recompress before inference (uses package:image)
Future<File> optimizeImageForInference(File originalImage) async {
  final image = img.decodeImage(await originalImage.readAsBytes());
  if (image == null) throw Exception('Invalid image');

  // Resize if too large; passing only `width:` preserves the aspect ratio
  final resized = img.copyResize(image, width: 512);

  // Compress to JPEG
  final compressed = img.encodeJpg(resized, quality: 85);

  final optimizedFile = File('${originalImage.path}_optimized.jpg');
  await optimizedFile.writeAsBytes(compressed);
  return optimizedFile;
}
```
### Debug Mode Features

```dart
class ClassificationService {
  static const bool debugMode = kDebugMode;

  Future<Map<String, dynamic>> classifyImage(File imageFile) async {
    if (debugMode) {
      print('Classifying image: ${imageFile.path}');
      print('Image size: ${await imageFile.length()} bytes');
    }

    final stopwatch = Stopwatch()..start();
    final result = await _performClassification(imageFile);
    stopwatch.stop();

    if (debugMode) {
      print('Classification took: ${stopwatch.elapsedMilliseconds}ms');
      print('Result: ${result['scientificName']} (${result['confidence']})');
    }

    return result;
  }
}
```
## Future Enhancements

### Planned Improvements
- Model Quantization: Reduce model size and improve inference speed
- Multi-Model Support: Support for different model architectures
- Edge TPU Integration: Hardware acceleration on supported devices
- Federated Learning: Contribute to model improvement while preserving privacy
- Real-time Classification: Video stream classification capabilities
### Extension Points
```dart
abstract class ClassificationProvider {
  Future<ClassificationResult> classify(File imageFile);
  bool get isAvailable;
  String get providerName;
}

class PyTorchClassificationProvider implements ClassificationProvider {
  // Current implementation
}

class TensorFlowLiteProvider implements ClassificationProvider {
  // Alternative implementation
}

class CloudMLProvider implements ClassificationProvider {
  // Cloud-based classification
}
```
## Conclusion

The AI model integration in CulicidaeLab provides robust, on-device mosquito species classification through PyTorch Lite. The architecture supports testing, performance monitoring, and future enhancements while maintaining reliability and user experience quality.

Key benefits:

- **Privacy:** On-device processing keeps user data local
- **Performance:** Optimized for mobile inference
- **Reliability:** Comprehensive error handling and recovery
- **Maintainability:** Clean architecture with dependency injection
- **Testability:** Mockable interfaces for comprehensive testing
For additional support or questions about AI model integration, refer to the PyTorch Mobile documentation or create an issue in the project repository.