# vsp-demo/tests/test_harness/enum_classifier_test.py
# Last commit by navkast: "Fix all references in main and workbooks (#11)" (518f864, unverified)
import asyncio
import json
from pathlib import Path
from typing import Any
from pydantic import BaseModel
from vsp.app.main import VspDataEnrichment
from vsp.app.model.linkedin.linkedin_models import LinkedinProfile
from vsp.app.model.vsp.vsp_models import VSPProfile
from vsp.shared import logger_factory
# Module-level logger from the project's logger factory, keyed by this module's __name__.
logger = logger_factory.get_logger(__name__)
class ComparisonResult(BaseModel):
    """Scorecard for one profile: classifier output versus the hand-labelled enums."""

    # "<first_name> <last_name>" of the evaluated LinkedIn profile.
    profile_name: str
    # Number of labelled enums the classifier reproduced exactly.
    correct_enums: int
    # Number of labelled (non-None) enums that were scored.
    total_enums: int
    # correct_enums / total_enums; 0 when there was nothing to score.
    accuracy: float
    # Enums the classifier emitted where the labelled profile had no value.
    additional_enums_in_actual: int
    # Per-entry details keyed like "Education: <school>" or "Job: <title> at <company>".
    comparisons: dict[str, Any]
def load_profiles(
    sample_profiles_dir: Path = Path("tests/test_data/sample_profiles"),
    classified_profiles_dir: Path = Path("tests/test_data/sample_profiles_classified"),
) -> dict[str, tuple[LinkedinProfile, VSPProfile]]:
    """Load every sample LinkedIn profile together with its hand-classified VSP profile.

    Profiles are paired by file stem: ``<sample_profiles_dir>/<name>.json`` is matched with
    ``<classified_profiles_dir>/<name>.json``. A sample without a classified counterpart has
    nothing to be scored against, so it is skipped (with a warning) and not parsed at all.

    Args:
        sample_profiles_dir: Directory holding raw LinkedIn profile JSON files. The default is
            relative to the current working directory, i.e. the repository root.
        classified_profiles_dir: Directory holding the expected (labelled) VSP profile JSON files.

    Returns:
        Mapping of profile name (file stem) to ``(linkedin_profile, classified_profile)``,
        inserted in sorted name order.
    """
    sample_dir = Path(sample_profiles_dir)
    classified_dir = Path(classified_profiles_dir)
    profiles: dict[str, tuple[LinkedinProfile, VSPProfile]] = {}
    # glob() order depends on the filesystem; sort so runs and result files are reproducible.
    for profile_file in sorted(sample_dir.glob("*.json")):
        name = profile_file.stem
        classified_file = classified_dir / f"{name}.json"
        if not classified_file.exists():
            logger.warning(f"No classified profile found for {name}; skipping")
            continue
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with profile_file.open(encoding="utf-8") as f:
            linkedin_profile = LinkedinProfile.model_validate(json.load(f))
        with classified_file.open(encoding="utf-8") as f:
            classified_profile = VSPProfile.model_validate(json.load(f))
        profiles[name] = (linkedin_profile, classified_profile)
    return profiles
def _unique_key(comparisons: dict[str, Any], key: str) -> str:
    """Return ``key``, or ``key`` suffixed with " #2", " #3", ... if it is already in ``comparisons``.

    Keeps two entries with the same label (e.g. two degrees from one school) from
    overwriting each other.
    """
    if key not in comparisons:
        return key
    suffix = 2
    while f"{key} #{suffix}" in comparisons:
        suffix += 1
    return f"{key} #{suffix}"


def _warn_on_length_mismatch(profile_name: str, label: str, expected: list[Any], actual: list[Any]) -> None:
    """Log a warning when expected and classified entry lists differ in length (zip drops the extras)."""
    if len(expected) != len(actual):
        logger.warning(
            f"{profile_name}: {len(expected)} expected {label} vs {len(actual)} classified; "
            f"unmatched entries are not scored"
        )


async def compare_profiles(linkedin_profile: LinkedinProfile, classified_profile: VSPProfile) -> ComparisonResult:
    """Classify ``linkedin_profile`` and score its enums against ``classified_profile``.

    Education and work-experience entries are paired positionally with the classifier's
    output. Each expected/actual enum pair is scored as follows:

    * expected value present: counts toward ``total_enums``, and toward ``correct_enums`` when
      the actual value equals it. A classification that was not produced at all counts as wrong.
    * expected value absent but actual present: counts toward ``additional_enums_in_actual``.
    * both absent: ignored.

    Args:
        linkedin_profile: Raw profile fed to ``VspDataEnrichment.process_linkedin_profile``.
        classified_profile: Hand-labelled profile holding the expected enum values.

    Returns:
        A ``ComparisonResult`` with the tallies, the accuracy (0.0 when nothing was scored),
        and per-entry comparison details.
    """
    vsp_enrichment = VspDataEnrichment()
    result = await vsp_enrichment.process_linkedin_profile(linkedin_profile)

    profile_name = f"{linkedin_profile.first_name} {linkedin_profile.last_name}"
    comparisons: dict[str, Any] = {}
    correct_enums = 0
    total_enums = 0
    additional_enums_in_actual = 0

    def score(expected: Any, actual: Any) -> None:
        """Update the running tallies for one expected/actual enum pair."""
        nonlocal correct_enums, total_enums, additional_enums_in_actual
        if expected is not None:
            total_enums += 1
            if expected == actual:
                correct_enums += 1
        elif actual is not None:
            additional_enums_in_actual += 1

    # Compare educations
    expected_educations = list(classified_profile.education.education_history)
    actual_educations = list(result.classified_educations)
    _warn_on_length_mismatch(profile_name, "educations", expected_educations, actual_educations)
    for classified_edu, result_edu in zip(expected_educations, actual_educations):
        classification = result_edu.classification
        actual = classification.output.value
        key = _unique_key(comparisons, f"Education: {classified_edu.school}")
        comparisons[key] = {
            "expected": classified_edu.degree_characterization,
            "actual": actual,
            "confidence": classification.confidence,
            "reasoning": classification.reasoning,
        }
        score(classified_edu.degree_characterization, actual)

    # Compare work experiences
    expected_experiences = list(classified_profile.professional_experience.experience_history)
    actual_experiences = list(result.classified_work_experiences)
    _warn_on_length_mismatch(profile_name, "work experiences", expected_experiences, actual_experiences)
    for classified_exp, result_exp in zip(expected_experiences, actual_experiences):
        exp_key = _unique_key(comparisons, f"Job: {classified_exp.title} at {classified_exp.company}")
        work = result_exp.work_experience_classification
        primary = work.primary_job_type.value
        secondary = work.secondary_job_type.value
        comparisons[exp_key] = {
            "primary_job_type": {
                "expected": classified_exp.primary_job_type,
                "actual": primary,
                "confidence": work.confidence,
                "reasoning": work.reasoning,
            },
            "secondary_job_type": {
                "expected": classified_exp.secondary_job_type,
                "actual": secondary,
                "confidence": work.confidence,
                "reasoning": work.reasoning,
            },
        }
        score(classified_exp.primary_job_type, primary)
        score(classified_exp.secondary_job_type, secondary)

        # Optional specialised classifications. Each label is also the attribute name of the
        # enum on its classification object.
        optional_classifications = (
            (
                "investment_banking_group",
                classified_exp.investment_banking_focus,
                result_exp.investment_banking_classification,
            ),
            (
                "investing_focus_asset_class",
                # NOTE(review): the expected side reads investing_focus_stage while the actual side is
                # the asset-class classification -- confirm this pairing is intended.
                classified_exp.investing_focus_stage,
                result_exp.investing_focus_asset_class_classification,
            ),
            (
                "investing_focus_sector",
                classified_exp.investing_focus_sector,
                result_exp.investing_focus_sector_classification,
            ),
        )
        for label, expected, classification in optional_classifications:
            if classification:
                actual = getattr(classification, label).value
                comparisons[exp_key][label] = {
                    "expected": expected,
                    "actual": actual,
                    "confidence": classification.confidence,
                    "reasoning": classification.reasoning,
                }
            else:
                # No classification was produced. A labelled value is still scored (as a miss)
                # so that skipped classifications cannot inflate accuracy.
                actual = None
                if expected is not None:
                    comparisons[exp_key][label] = {
                        "expected": expected,
                        "actual": None,
                        "confidence": None,
                        "reasoning": None,
                    }
            score(expected, actual)

    accuracy = correct_enums / total_enums if total_enums > 0 else 0.0
    return ComparisonResult(
        profile_name=profile_name,
        correct_enums=correct_enums,
        total_enums=total_enums,
        accuracy=accuracy,
        additional_enums_in_actual=additional_enums_in_actual,
        comparisons=comparisons,
    )
async def run_tests() -> None:
    """Score every labelled sample profile, log a summary, and write detailed JSON results.

    Profiles are processed one after another. The per-profile details are written to
    ``enum_classifier_results.json`` in the current working directory. When no labelled
    profiles are found, a warning is logged and nothing is written.
    """
    profiles = load_profiles()
    if not profiles:
        logger.warning("No labelled sample profiles found; nothing to evaluate")
        return

    results: list[ComparisonResult] = []
    for linkedin_profile, classified_profile in profiles.values():
        result = await compare_profiles(linkedin_profile, classified_profile)
        results.append(result)
        logger.info(
            f"Processed {result.profile_name}: Accuracy {result.accuracy:.2%}, "
            f"Additional enums in actual: {result.additional_enums_in_actual}"
        )

    # Average only over profiles that had at least one labelled enum. An unlabelled profile
    # reports accuracy 0 by construction and would otherwise drag the mean down.
    scored = [r for r in results if r.total_enums > 0]
    overall_accuracy = sum(r.accuracy for r in scored) / len(scored) if scored else 0.0
    total_additional_enums = sum(r.additional_enums_in_actual for r in results)
    logger.info(f"Overall accuracy: {overall_accuracy:.2%}")
    logger.info(f"Total additional enums in actual: {total_additional_enums}")

    # Save detailed results to a JSON file. mode="json" converts values such as Enums in the
    # free-form `comparisons` dict into JSON-native types that json.dump can serialise.
    with open("enum_classifier_results.json", "w", encoding="utf-8") as f:
        json.dump([r.model_dump(mode="json") for r in results], f, indent=2)
    logger.info("Detailed results saved to enum_classifier_results.json")
if __name__ == "__main__":
    # Script entry point: build the harness coroutine, then drive it on a fresh event loop.
    harness = run_tests()
    asyncio.run(harness)