Untitled

mail@pastecode.io avatar
unknown
python
a month ago
5.3 kB
3
Indexable
Never
import re
import time
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
import pandas as pd
from tqdm import tqdm
import datetime

class ConversationAnalyzer:
    """Score the coherence of a text with a Groq chat model through a LangChain chain.

    The chain is ``prompt | ChatGroq | StrOutputParser``: the model is asked for a
    score between 0 and 1 plus an explanation, and :meth:`analyze_text` parses
    that reply, retrying on errors.
    """

    # A number that follows the word "score" ("Score: 0.8", "a score of 0.8").
    _LABELLED_SCORE_RE = re.compile(r"score\D{0,20}?(\d+(?:\.\d+)?)", re.IGNORECASE)
    # Any standalone integer or decimal number.
    _NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")

    def __init__(self, api_key, model_name="mixtral-8x7b-32768", max_retries=5,
                 base_wait_time=60, retry_wait_time=5):
        """Build the prompt -> model -> string-parser chain.

        Args:
            api_key: Groq API key passed to ``ChatGroq``.
            model_name: Groq model identifier.
            max_retries: maximum number of model calls made per text.
            base_wait_time: seconds to wait after a rate-limit error on the first
                attempt; the wait is ``base_wait_time * 2 ** attempt``.
            retry_wait_time: seconds to wait after any other error.
        """
        self.model = ChatGroq(model_name=model_name, api_key=api_key)
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an expert in analyzing conversations. Your task is to determine the coherence of a given text. Focus on the overall coherence and natural flow of the text."),
            ("human", "Text: {text}\n\nAnalyze the coherence of this text and respond with a score between 0 (not at all coherent) and 1 (completely coherent), and explain your reasoning.")
        ])
        self.chain = self.prompt | self.model | StrOutputParser()
        self.max_retries = max_retries
        self.base_wait_time = base_wait_time
        self.retry_wait_time = retry_wait_time

    @classmethod
    def _parse_score(cls, result):
        """Extract a coherence score in [0, 1] from the model's reply.

        A number directly following the word "score" is preferred. Otherwise the
        first standalone number in [0, 1] is used. This avoids picking up list
        numbering or "8/10"-style values as the score.

        Raises:
            ValueError: if the reply contains no number in [0, 1].
        """
        labelled = cls._LABELLED_SCORE_RE.search(result)
        candidates = [labelled.group(1)] if labelled else []
        candidates += cls._NUMBER_RE.findall(result)
        for token in candidates:
            value = float(token)
            if 0.0 <= value <= 1.0:
                return value
        raise ValueError("No numeric score found in the result")

    def analyze_text(self, text):
        """Ask the model to rate the coherence of ``text``.

        Rate-limit errors (messages containing ``rate_limit_exceeded``) back off
        exponentially. Other errors, including replies without a usable score,
        wait ``retry_wait_time`` seconds. No wait happens after the final attempt.

        Args:
            text: the text to analyze.

        Returns:
            ``(score, explanation)``. ``score`` is a float in [0, 1].
            ``explanation`` is the reply after its first line, or the whole
            stripped reply if it is a single line. If every attempt fails (or
            ``max_retries`` is not positive), returns
            ``(0.0, "Failed to analyze text after multiple attempts.")``.
            This method never returns ``None``.
        """
        for attempt in range(self.max_retries):
            try:
                result = self.chain.invoke({"text": text})
                score = self._parse_score(result)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    # Last attempt: don't sleep, just give up.
                    print(f"All retry attempts failed. Error: {e}")
                    break
                if "rate_limit_exceeded" in str(e):
                    wait_time = self.base_wait_time * (2 ** attempt)
                    print(f"Rate limit exceeded. Waiting for {wait_time} seconds before retrying.")
                else:
                    wait_time = self.retry_wait_time
                    print(f"Error analyzing text (attempt {attempt + 1}/{self.max_retries}): {e}")
                time.sleep(wait_time)
            else:
                explanation = result.split("\n", 1)[1].strip() if "\n" in result else result.strip()
                return score, explanation
        return 0.0, "Failed to analyze text after multiple attempts."

def find_coherent_texts(loader, analyzer, file_name, log_file):
    """Score every sentence of a dataset and collect those with a positive score.

    Args:
        loader: object whose ``load_data(file_name)`` returns a DataFrame with a
            ``'sentence'`` column.
        analyzer: object whose ``analyze_text(text)`` returns
            ``(score, explanation)``, e.g. ``ConversationAnalyzer``.
        file_name: dataset identifier passed to ``loader.load_data``.
        log_file: path of a text log; every text with score > 0 is appended.

    Returns:
        DataFrame with columns ``index``, ``text``, ``score`` and
        ``explanation``, one row per text with score > 0. ``index`` is the row
        label of the text in the loaded DataFrame, so it can be mapped back with
        ``Series.map``. The columns are present even when no text qualifies.

    Side effects:
        Prints texts with score >= 0.8. Every 100 processed rows, it writes the
        rows collected so far to ``intermediate_results_<n>.csv`` in the current
        working directory.
    """
    columns = ["index", "text", "score", "explanation"]
    coherent_texts = []
    df = loader.load_data(file_name)
    sentences = df["sentence"]
    # Append mode keeps logs from earlier runs; explicit UTF-8 so non-ASCII
    # text cannot fail on platforms with a different default encoding.
    with open(log_file, "a", encoding="utf-8") as log:
        progress = tqdm(sentences.items(), total=len(sentences), desc=f"Processing {file_name}")
        # Iterate by label rather than df.loc[position] so non-RangeIndex frames work.
        for position, (label, text) in enumerate(progress, start=1):
            score, explanation = analyzer.analyze_text(text)
            if score > 0:
                coherent_texts.append({
                    'index': label,
                    'text': text,
                    'score': score,
                    'explanation': explanation
                })
                log.write(f"Coherent text found with score {score}:\n"
                          f"Text: {text}\n"
                          f"Explanation: {explanation}\n"
                          f"{'-' * 50}\n")
                log.flush()  # keep the log current if the long run is interrupted
                if score >= 0.8:
                    print(f"Coherent text found with score {score}:")
                    print(f"Text: {text}")
                    print(f"Explanation: {explanation}")
                    print("-" * 50)

            # Checkpoint the results collected so far every 100 rows.
            if position % 100 == 0:
                intermediate_df = pd.DataFrame(coherent_texts, columns=columns)
                intermediate_df.to_csv(f'intermediate_results_{position}.csv', index=False)
                print(f"Saved intermediate results at iteration {position}")

    return pd.DataFrame(coherent_texts, columns=columns)

def main():
    """Run the coherence analysis end to end and write the results to CSV.

    Writes the raw results to ``results.csv``. Then it enriches them with the
    ``path``, ``start_cd``, ``end_cd`` and ``times`` columns of the source
    DataFrame and writes them to ``coherent_texts_results.csv``. Any error is
    reported with its traceback. The elapsed time is always printed.
    """
    import os
    import traceback

    # Taken before ``try`` so the ``finally`` block can always use it.
    start = datetime.datetime.now()
    print(start)
    try:
        # NOTE(review): DataLoader is not defined or imported in this file — confirm its source.
        loader = DataLoader()
        # Prefer the environment over a key committed to source.
        api_key = os.environ.get("GROQ_API_KEY", "your_api_key_here")
        analyzer = ConversationAnalyzer(api_key=api_key)
        file_name = 'df_for_dori2.pkl'
        log_file = "logs.txt"
        results = find_coherent_texts(loader, analyzer, file_name, log_file)

        output_file = 'results.csv'
        results.to_csv(output_file, index=False)
        print("Saved results")

        # Load original dataframe to get additional information
        original_df = loader.load_data(file_name)

        # An empty result may come back without any columns; give it the
        # expected ones so the enrichment below works.
        if "index" not in results.columns:
            results = pd.DataFrame(columns=["index", "text", "score", "explanation"])

        # Look up each result's source row by its index label.
        extra_columns = ['path', 'start_cd', 'end_cd', 'times']
        for column in extra_columns:
            results[column] = results['index'].map(original_df[column])

        results = results[['index', 'path', 'text', 'start_cd', 'end_cd', 'times', 'score', 'explanation']]

        output_file = 'coherent_texts_results.csv'
        results.to_csv(output_file, index=False)
        print(f"Found {len(results)} coherent texts. Results saved to '{output_file}'")
    except Exception as e:
        print(f"Failed to run! Error: {e}")
        traceback.print_exc()
    finally:
        end = datetime.datetime.now()
        print(end)
        print(f"Time that took: {end - start}")


if __name__ == "__main__":
    main()
Leave a Comment