"""
concept_extraction.py
---------------------
Functions to extract, normalize, and process concept relationships from educational content (lesson readings and objectives).
Features:
- Summarizes lesson readings using LLMs and course-specific prompts.
- Extracts key concepts and their relationships, guided by lesson objectives.
- Normalizes and consolidates concept names using embeddings and inflection.
- Prepares structured data for downstream visualization and analysis.
Dependencies:
- Language Models: OpenAI GPT or similar, DistilBERT (transformers)
- Core Libraries: langchain, torch, inflect
Example:
from class_factory.concept_web.concept_extraction import extract_relationships
text = "Democracy relies on voting rights..."
objectives = "Understand principles of democracy"
relationships = extract_relationships(text, objectives, "Political Science", llm)
processed = process_relationships(relationships)
This module is part of the ClassFactory concept mapping pipeline.
"""
# %%
# base libraries
import json
# logger setup
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import inflect
# entity resolution
import torch
import torch.nn.functional as F
# env setup
from dotenv import load_dotenv
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from transformers import DistilBertModel, DistilBertTokenizer
from class_factory.concept_web.prompts import (relationship_prompt,
summary_prompt)
from class_factory.utils.llm_validator import Validator
from class_factory.utils.response_parsers import ExtractedRelations
from class_factory.utils.tools import logger_setup, retry_on_json_decode_error
# logging.basicConfig(
# level=logging.INFO, # Set your desired level
# format='%(name)s - %(levelname)s - %(message)s'
# )
# %%
def summarize_text(text: str, prompt: ChatPromptTemplate, course_name: str, llm: Any,
parser: StrOutputParser = StrOutputParser(), verbose: bool = False) -> str:
"""
Summarize the provided text using a language model and a structured prompt.
Args:
text (str): The text to be summarized.
prompt (ChatPromptTemplate): Prompt template for summarization.
course_name (str): Name of the course for context.
llm (Any): Language model instance.
parser (StrOutputParser, optional): Output parser. Defaults to StrOutputParser().
verbose (bool, optional): Enable detailed logging. Defaults to False.
Returns:
str: The summary generated by the language model.
Raises:
ValueError: If validation fails after max retries.
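Example (illustrative sketch; assumes ``llm`` is an initialized chat model and
    ``reading_text`` holds the lesson reading to summarize):
    summary = summarize_text(reading_text, prompt=summary_prompt,
                             course_name="Political Science", llm=llm)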
"""
log_level = logging.INFO if verbose else logging.ERROR
logger = logger_setup(log_level=log_level)
chain = prompt | llm | parser
retries, max_retries = 0, 3
valid = False
validator = Validator(llm=llm, log_level=log_level)
additional_guidance = ""
while not valid and retries < max_retries:
summary = chain.invoke({'course_name': course_name,
'text': text,
'additional_guidance': additional_guidance})
logger.info(f"Example summary:\n{summary}")
# Validate the generated summary
validation_prompt = prompt.format(course_name=course_name,
text=text,
additional_guidance=additional_guidance)
# Use the schema from the parser if available
task_schema = parser.pydantic_object.model_json_schema() if hasattr(
parser, 'pydantic_object') and hasattr(parser.pydantic_object, 'model_json_schema') else ""
val_response = validator.validate(
task_description=validation_prompt,
generated_response=summary,
task_schema=task_schema,
specific_guidance=additional_guidance
)
logger.info(f"validation output: {val_response}")
if int(val_response['status']) == 1:
valid = True
else:
retries += 1
additional_guidance = val_response.get("additional_guidance", "")
logger.warning(f"Summary validation failed on attempt {retries}. Reason: {val_response['reasoning']}")
if valid:
logger.debug("Validation succeeded.")
else:
raise ValueError("Validation failed after max retries. Ensure correct prompt and input data. Consider use of a different LLM.")
return summary
def get_embeddings(concepts: List[str]) -> Dict[str, torch.Tensor]:
"""
Generate normalized embeddings for a list of concepts using DistilBERT.
Args:
concepts (List[str]): List of concept strings.
Returns:
Dict[str, torch.Tensor]: Mapping of concept to normalized embedding tensor.
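Example (illustrative; the 768-dimensional vectors come from distilbert-base-uncased):
    embeddings = get_embeddings(["democracy", "voting rights"])
    embeddings["democracy"].shape  # torch.Size([768])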
"""
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
# Tokenize and process concepts in batches
inputs = tokenizer(concepts, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
outputs = model(**inputs)
# Use the [CLS] token's embedding
embeddings = outputs.last_hidden_state[:, 0, :] # Shape: (num_concepts, hidden_size)
# Normalize embeddings
normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
# Map concepts directly to embeddings
return {concept: normalized_embeddings[idx] for idx, concept in enumerate(concepts)}
def normalize_concept(concept: str) -> str:
"""
Normalize a single concept string (lowercase, replace underscores with spaces, singularize each word).
Args:
concept (str): Concept string.
Returns:
str: Normalized concept string.
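Example (illustrative; exact singular forms depend on ``inflect``):
    normalize_concept("Political_Parties")  # -> 'political party'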
"""
p = inflect.engine()
words = concept.lower().strip().replace('_', ' ').split()
normalized_words = [p.singular_noun(word) or word for word in words]
return " ".join(normalized_words)
def normalize_for_embedding(concepts: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Normalize one or more concepts for embedding.
Args:
concepts (Union[str, List[str]]): Concept or list of concepts.
Returns:
Union[str, List[str]]: Normalized concept(s).
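Example (illustrative; singularization depends on ``inflect``):
    normalize_for_embedding(["Political_Parties", "Voting Rights"])
    # -> ['political party', 'voting right']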
"""
if isinstance(concepts, str):
return normalize_concept(concepts)
return [normalize_concept(concept) for concept in concepts]
def normalize_for_output(concept: str) -> str:
"""
Format a concept for output by replacing spaces with underscores and removing 'is'.
Args:
concept (str): Concept string.
Returns:
str: Output-formatted concept string.
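Example:
    normalize_for_output("separation of powers")  # -> 'separation_of_powers'
    normalize_for_output("is part of")  # -> 'part_of'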
"""
concept_words = [word for word in concept.split() if word != 'is']
return "_".join(concept_words)
def replace_similar_concepts(existing_concepts: Set[str], new_concept: str,
concept_embeddings: Dict[str, torch.Tensor],
threshold: float = 0.995) -> str:
"""
Replace a new concept with an existing similar concept if cosine similarity exceeds threshold.
Args:
existing_concepts (Set[str]): Set of existing concepts.
new_concept (str): New concept string.
concept_embeddings (Dict[str, torch.Tensor]): Embeddings for all concepts.
threshold (float, optional): Similarity threshold. Defaults to 0.995.
Returns:
str: Existing or new concept string.
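Example (illustrative sketch; assumes both spellings were embedded with ``get_embeddings``):
    embeddings = get_embeddings(["voting right", "voting rights"])
    replace_similar_concepts({"voting right"}, "voting rights", embeddings)
    # -> 'voting right' if the cosine similarity clears the threshold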
"""
# Get the embedding of the new concept
new_embedding = concept_embeddings[new_concept]
for existing_concept in existing_concepts:
existing_embedding = concept_embeddings[existing_concept]
# Compute cosine similarity
similarity = torch.matmul(new_embedding, existing_embedding).item()
# If similar, return the existing concept
if similarity >= threshold:
return existing_concept
# If no similar concept is found, return the new concept
return new_concept
def process_relationships(relationships: List[Tuple[str, str, str]],
threshold: float = 0.995,
max_retries: int = 3) -> List[Tuple[str, str, str]]:
"""
Normalize and consolidate relationships by merging similar concepts.
Args:
relationships (List[Tuple[str, str, str]]): List of (concept1, relationship, concept2) tuples.
threshold (float, optional): Similarity threshold for merging. Defaults to 0.995.
max_retries (int, optional): Max attempts to resolve duplicates. Defaults to 3.
Returns:
List[Tuple[str, str, str]]: Processed relationships with normalized concepts.
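Example (illustrative; merged names depend on embedding similarity and ``inflect``):
    relationships = [("Democracy", "relies on", "Voting Rights"),
                     ("Voting Right", "is protected by", "Constitution")]
    process_relationships(relationships)
    # -> e.g. [('democracy', 'relies_on', 'voting_right'),
    #          ('voting_right', 'protected_by', 'constitution')]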
"""
# Initialize a set to keep track of all unique concepts
unique_concepts = set()
processed_relationships = []
# Accept both dict and tuple for backward compatibility
if relationships and isinstance(relationships[0], dict):
relationships = [
(rel["concept_1"], rel["relationship_type"], rel["concept_2"])
for rel in relationships
if all(k in rel for k in ("concept_1", "relationship_type", "concept_2"))
]
# get concepts and embeddings
extracted_concepts = extract_concepts_from_relationships(relationships)
conceptlist = normalize_for_embedding(extracted_concepts)
concept_embeddings = get_embeddings(conceptlist)
for c1, relationship, c2 in relationships:
c1 = normalize_for_embedding(c1)
c2 = normalize_for_embedding(c2)
# Replace similar concepts with existing ones
retries = 0
concept1 = replace_similar_concepts(unique_concepts, c1, concept_embeddings, threshold)
concept2 = replace_similar_concepts(unique_concepts, c2, concept_embeddings, threshold)
# Retry resolution if both map to the same concept, tightening the similarity threshold each time.
# Max threshold increase = 3 * 0.0015 = 0.0045 (raising the default 0.995 to at most 0.9995)
while concept1 == concept2 and retries < max_retries:
retries += 1
concept1 = replace_similar_concepts(unique_concepts, c1, concept_embeddings, threshold + retries * 0.0015)
concept2 = replace_similar_concepts(unique_concepts, c2, concept_embeddings, threshold + retries * 0.0015)
# # If still identical, skip or revert to original
# if concept1 == concept2:
# continue # Skip self-referential relationships
# Add concepts to the unique set
unique_concepts.add(concept1)
unique_concepts.add(concept2)
# Normalize concepts
clean_concept1 = normalize_for_output(concept1)
clean_concept2 = normalize_for_output(concept2)
clean_relation = normalize_for_output(relationship)
# Add the relationship to the processed list
processed_relationships.append((clean_concept1, clean_relation, clean_concept2))
return processed_relationships
# %%
if __name__ == "__main__":
# llm chain setup
import yaml
from langchain_community.llms import Ollama
from langchain_openai import ChatOpenAI
from pyprojroot.here import here
# self-defined utils
from class_factory.utils.load_documents import LessonLoader
user_home = Path.home()
load_dotenv()
OPENAI_KEY = os.getenv('openai_key')
OPENAI_ORG = os.getenv('openai_org')
# Path definitions
with open("class_config.yaml", "r") as file:
config = yaml.safe_load(file)
class_config = config['PS211']
slide_dir = user_home / class_config['slideDir']
syllabus_path = user_home / class_config['syllabus_path']
readingDir = user_home / class_config['reading_dir']
is_tabular_syllabus = class_config['is_tabular_syllabus']
projectDir = here()
# parser = JsonOutputParser(pydantic_object=ExtractedRelations)
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=OPENAI_KEY,
organization=OPENAI_ORG,
)
# llm = Ollama(
# model="llama3.1",
# temperature=0.5,
# )
relationship_list = []
conceptlist = []
loader = LessonLoader(syllabus_path=syllabus_path,
reading_dir=readingDir,
slide_dir=None)
# Load documents and lesson objectives
for lesson_num in range(3, 4):
print(f"Lesson {lesson_num}")
lesson_objectives = loader.extract_lesson_objectives(current_lesson=lesson_num, only_current=True)
documents = loader.load_lessons(lesson_number_or_range=range(lesson_num, lesson_num + 1))
if not documents:
continue
for lsn, readings in documents.items():
for reading in readings:
summary = summarize_text(reading,
prompt=summary_prompt,
course_name="American government",
llm=llm,
verbose=True)
# print(summary)
relationships = extract_relationships(summary,
lesson_objectives,
course_name="American government",
llm=llm,
verbose=True)
print(relationships)
relationship_list.extend(relationships)
concepts = extract_concepts_from_relationships(relationships)
conceptlist.extend(concepts)
processed_relationships = process_relationships(relationship_list)
# %%