import pandas as pd
import random
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
# Step 1: Define Patient Data Input
class PatientData:
    """Input record describing one patient to be scored by the treatment model.

    Each constructor argument is stored unchanged as an attribute of the same
    name. ``symptoms`` is kept by reference; callers in this file pass a list of
    symptom names.
    """

    def __init__(self, presenting_complaint, prakrati, patient_history, investigations, genome_analysis,
                 symptoms_duration, lifestyle, allergies, symptoms):
        # Clinical background: complaint, constitution type, history, tests, genome.
        self.presenting_complaint = presenting_complaint
        self.prakrati = prakrati
        self.patient_history = patient_history
        self.investigations = investigations
        self.genome_analysis = genome_analysis

        # Course of the illness and personal factors.
        self.symptoms_duration = symptoms_duration
        self.lifestyle = lifestyle
        self.allergies = allergies
        self.symptoms = symptoms
# Step 2: Dynamic Dataset Creation for Tree Training
class DatasetGenerator:
    """Build and grow an in-memory training table of patient records.

    The table starts with four hand-written seed patients and can be extended
    with :meth:`add_patient`. Each row's ``Symptoms`` cell holds a list of
    symptom names drawn from a synthetic pool
    (``symptom_1`` ... ``symptom_<pool_size>``).
    """

    def __init__(self, pool_size=600, symptoms_per_patient=5, seed=None):
        """Create the seed dataset.

        Args:
            pool_size: Number of synthetic symptom names to generate.
            symptoms_per_patient: Symptoms sampled for each seed row. This is
                also the default size for :meth:`sample_symptoms`.
            seed: Optional seed for reproducible sampling. When ``None``, the
                module-level ``random`` generator is used, so ``random.seed``
                keeps controlling the output as before.

        Raises:
            ValueError: If ``symptoms_per_patient`` exceeds ``pool_size``
                (raised by ``random.sample``).
        """
        # Use a private generator only when a seed is requested. Otherwise share
        # the global one so existing `random.seed(...)` calls still apply.
        self._rng = random.Random(seed) if seed is not None else random
        self.symptoms_per_patient = symptoms_per_patient
        self.symptoms_pool = [f"symptom_{i}" for i in range(1, pool_size + 1)]
        self.columns = ['Presenting Complaint', 'Prakrati', 'History', 'Investigation', 'Genome',
                        'Symptoms Duration', 'Lifestyle', 'Allergies', 'Symptoms', 'Treatment Pathway']
        initial_rows = {
            'Presenting Complaint': ['diabetes', 'hypertension', 'arthritis', 'headache'],
            'Prakrati': ['kapha', 'vata', 'vata', 'pitta'],
            'History': ['obesity', 'stress', 'joint pain', 'migraine'],
            'Investigation': ['blood sugar', 'BP readings', 'x-ray', 'neuro scan'],
            'Genome': ['type 2 diabetes risk', 'stress gene', 'joint weakness', 'migraine tendency'],
            'Symptoms Duration': ['chronic', 'acute', 'chronic', 'episodic'],
            'Lifestyle': ['sedentary', 'active', 'sedentary', 'moderate'],
            'Allergies': ['none', 'pollen', 'NSAIDs', 'none'],
            'Treatment Pathway': ['metformin, ashwagandha', 'amlodipine, yoga',
                                  'NSAIDs, guggulu', 'paracetamol, brahmi'],
        }
        # One random symptom list per seed row, derived from the row count
        # instead of a separate hard-coded 4.
        row_count = len(initial_rows['Presenting Complaint'])
        initial_rows['Symptoms'] = [self.sample_symptoms() for _ in range(row_count)]
        # `columns=` pins the table's column order to self.columns.
        self.data = pd.DataFrame(initial_rows, columns=self.columns)

    def sample_symptoms(self, k=None):
        """Return ``k`` distinct symptom names from the pool.

        Args:
            k: How many to draw; defaults to ``symptoms_per_patient``.
        """
        if k is None:
            k = self.symptoms_per_patient
        return self._rng.sample(self.symptoms_pool, k)

    def add_patient(self, presenting_complaint, prakrati, history, investigation, genome,
                    symptoms_duration, lifestyle, allergies, symptoms, treatment_pathway):
        """Append one patient record to ``self.data`` and renumber the index.

        ``symptoms`` is stored as-is in the row's ``Symptoms`` cell.
        """
        row = {
            'Presenting Complaint': presenting_complaint,
            'Prakrati': prakrati,
            'History': history,
            'Investigation': investigation,
            'Genome': genome,
            'Symptoms Duration': symptoms_duration,
            'Lifestyle': lifestyle,
            'Allergies': allergies,
            'Symptoms': symptoms,
            'Treatment Pathway': treatment_pathway,
        }
        # A list containing one dict builds a single row, keeping the symptom
        # list intact inside one cell.
        new_frame = pd.DataFrame([row], columns=self.columns)
        self.data = pd.concat([self.data, new_frame], ignore_index=True)

    def get_dataset(self):
        """Return the current table (the live DataFrame, not a copy)."""
        return self.data
# Step 3: Decision Tree Model
class TreatmentDecisionTree:
    """Decision-tree classifier that maps patient features to a treatment pathway.

    Every feature column is label-encoded with its own ``LabelEncoder``. The
    ``Symptoms`` column, which holds a list of symptom names per row, is first
    collapsed into one order-independent string so it can be encoded like the
    other columns.

    NOTE(review): label encoding imposes an arbitrary numeric order on nominal
    categories. It also turns each whole symptom combination into a single
    category, so the tree can only match symptom sets seen in training. A
    one-hot or multi-label encoding would generalize better.
    """

    # Code assigned at prediction time to a category never seen during training.
    # Fitted LabelEncoder codes are 0..n-1, so this sorts below every known one.
    UNKNOWN_CODE = -1

    # Feature column -> PatientData attribute. An explicit map is needed because
    # several column names do not match the attribute names
    # (e.g. 'History' vs `patient_history`).
    _PATIENT_ATTRIBUTES = {
        'Presenting Complaint': 'presenting_complaint',
        'Prakrati': 'prakrati',
        'History': 'patient_history',
        'Investigation': 'investigations',
        'Genome': 'genome_analysis',
        'Symptoms Duration': 'symptoms_duration',
        'Lifestyle': 'lifestyle',
        'Allergies': 'allergies',
        'Symptoms': 'symptoms',
    }

    def __init__(self, dataset, target_column='Treatment Pathway'):
        """Wrap a training table. No fitting happens here.

        Args:
            dataset: DataFrame containing the feature columns plus ``target_column``.
            target_column: Name of the label column. Every other column is a feature.

        Raises:
            KeyError: If ``target_column`` is not a column of ``dataset``.
        """
        if target_column not in dataset.columns:
            raise KeyError(f"Target column {target_column!r} not found in dataset.")
        self.dataset = dataset
        self.model = None
        self.encoder = {}  # feature column -> fitted LabelEncoder
        self.target_column = target_column
        # Features are every column except the target, wherever the target sits.
        self.feature_columns = dataset.columns.drop(target_column)
        self.preprocessed_data = None

    def preprocess_data(self):
        """Label-encode every feature column and split features from the target.

        Fits a fresh ``LabelEncoder`` per feature column and stores it in
        ``self.encoder``, so :meth:`predict` can encode new patients the same way.

        Returns:
            tuple: ``(X, y)``, the encoded feature frame and the raw target
            series. Also cached in ``self.preprocessed_data``.
        """
        encoded_data = self.dataset.copy()
        if 'Symptoms' in self.feature_columns:
            encoded_data['Symptoms'] = encoded_data['Symptoms'].apply(self._join_symptoms)
        for column in self.feature_columns:
            self.encoder[column] = LabelEncoder()
            encoded_data[column] = self.encoder[column].fit_transform(encoded_data[column])
        X = encoded_data[self.feature_columns]
        y = encoded_data[self.target_column]
        self.preprocessed_data = (X, y)
        return X, y

    def train_tree(self, max_depth=None, criterion='gini', random_state=None):
        """Preprocess the dataset and fit a ``DecisionTreeClassifier``.

        Args:
            max_depth: Maximum tree depth (``None`` = unlimited).
            criterion: Split quality measure passed to the classifier.
            random_state: Optional seed passed to the classifier for
                reproducible tie-breaking between equally good splits.

        Returns:
            The fitted classifier (also stored in ``self.model``).
        """
        X, y = self.preprocess_data()
        self.model = DecisionTreeClassifier(max_depth=max_depth, criterion=criterion,
                                            random_state=random_state)
        self.model.fit(X, y)
        print("Decision Tree trained successfully.")
        return self.model

    def visualize_tree(self):
        """Print the tree's rules as text, then plot it with matplotlib.

        Raises:
            RuntimeError: If the model has not been trained.
        """
        self._require_trained()
        feature_names = list(self.feature_columns)
        tree_rules = export_text(self.model, feature_names=feature_names)
        print("Decision Tree Rules:")
        print(tree_rules)
        plt.figure(figsize=(20, 10))
        # class_names must follow model.classes_ (sorted order), not the
        # dataset's order of appearance, or leaves show the wrong pathway.
        plot_tree(self.model, feature_names=feature_names,
                  class_names=[str(label) for label in self.model.classes_],
                  filled=True, rounded=True)
        plt.title("Treatment Decision Tree")
        plt.show()

    def predict(self, patient_data):
        """Recommend a treatment pathway for one patient.

        Args:
            patient_data: A ``PatientData``-like object exposing the attributes
                listed in ``_PATIENT_ATTRIBUTES``.

        Returns:
            The predicted treatment pathway label.

        Raises:
            RuntimeError: If the model has not been trained.
            AttributeError: If ``patient_data`` lacks a required attribute.
        """
        self._require_trained()
        patient_input = [
            self._encode_value(self.encoder[column].classes_,
                               self._patient_value(patient_data, column))
            for column in self.feature_columns
        ]
        # Reuse the column names seen during fit so sklearn can match features.
        features = pd.DataFrame([patient_input], columns=list(self.feature_columns))
        # predict() returns the class label itself, not a row index.
        return self.model.predict(features)[0]

    # --- helpers -------------------------------------------------------------

    def _require_trained(self):
        """Raise if :meth:`train_tree` has not produced a model yet."""
        if self.model is None:
            raise RuntimeError("Model not trained yet.")

    @staticmethod
    def _join_symptoms(symptoms):
        """Collapse a symptom collection into one sorted, space-joined string.

        Sorting makes the encoding independent of symptom order. A value that
        is already a string is returned unchanged.
        """
        if isinstance(symptoms, str):
            return symptoms
        return ' '.join(sorted(str(symptom) for symptom in symptoms))

    @classmethod
    def _encode_value(cls, classes, value):
        """Return ``value``'s position in ``classes``, or ``UNKNOWN_CODE`` if absent.

        This matches ``LabelEncoder.transform`` for known values, but does not
        raise on categories unseen during training.
        """
        try:
            return list(classes).index(value)
        except ValueError:
            return cls.UNKNOWN_CODE

    @classmethod
    def _patient_value(cls, patient_data, column):
        """Read the raw value of ``column`` from ``patient_data``, ready to encode."""
        # Fall back to the old name-derivation rule for columns not in the map.
        attribute = cls._PATIENT_ATTRIBUTES.get(column, column.lower().replace(" ", "_"))
        value = getattr(patient_data, attribute)
        if column == 'Symptoms':
            value = cls._join_symptoms(value)
        return value
# Step 4: Main Application Logic
def main():
    """Run the end-to-end demo: build data, train, visualize, then predict.

    Builds the seed dataset, appends an asthma case, fits a tree limited to
    depth 10, shows its rules and plot, and prints the recommended pathway for
    a sample arthritis patient. The symptom lists chosen here come from the
    module-level ``random`` generator, so they differ between runs unless
    ``random`` is seeded first.
    """
    generator = DatasetGenerator()

    # Extra training example appended after the built-in seed rows.
    asthma_case = dict(
        presenting_complaint="asthma",
        prakrati="kapha",
        history="chronic wheezing",
        investigation="spirometry",
        genome="allergic gene markers",
        symptoms_duration="chronic",
        lifestyle="moderate",
        allergies="dust",
        symptoms=random.sample(generator.symptoms_pool, 5),
        treatment_pathway="inhalers, turmeric therapy",
    )
    generator.add_patient(**asthma_case)

    model = TreatmentDecisionTree(generator.get_dataset())
    model.train_tree(max_depth=10)
    model.visualize_tree()

    # Unseen patient to score with the trained tree.
    patient = PatientData(
        presenting_complaint="arthritis",
        prakrati="vata",
        patient_history="joint pain",
        investigations="x-ray",
        genome_analysis="joint weakness",
        symptoms_duration="chronic",
        lifestyle="sedentary",
        allergies="NSAIDs",
        symptoms=random.sample(generator.symptoms_pool, 5),
    )
    print(f"Recommended Treatment Pathway for the patient: {model.predict(patient)}")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()