# Untitled
# unknown
# python
# a year ago
# 7.2 kB
# 11
# Indexable
import pandas as pd
import time
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
class NameClassifier:
    """Decide whether a piece of text is a person's name by asking a Groq chat model.

    The model is prompted to answer 'Yes' or 'Unclaimed'. Failed calls are
    retried up to ``max_retries`` times; rate-limit errors back off
    exponentially starting from ``base_wait_time`` seconds.
    """

    def __init__(self, api_key, model_name="mixtral-8x7b-32768", max_retries=5, base_wait_time=60):
        """Build the prompt -> model -> string-parser chain.

        Args:
            api_key: Groq API key handed to ``ChatGroq``.
            model_name: Name of the Groq model to query.
            max_retries: Maximum number of attempts per classified text.
            base_wait_time: Seconds to wait after the first rate-limit error;
                doubled on every further attempt.
        """
        self.llm = ChatGroq(model_name=model_name, api_key=api_key)
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant specialized in identifying person names. Your task is to determine if the given text is a person's name or not."),
            ("human", "Is the following text a person's name? Respond with 'Yes' if it's a person's name, or 'Unclaimed' if it's not.\n\nText: {text}\n\nClassification:")
        ])
        # The leading mapping places the chain's raw input under the "text"
        # key, so the chain must be invoked with the bare string.
        self.chain = (
            {"text": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        self.max_retries = max_retries
        self.base_wait_time = base_wait_time

    def classify_name(self, text):
        """Return ``text`` if the model says it is a person's name, else "Unclaimed".

        Args:
            text: The candidate string to classify.

        Returns:
            The original ``text`` when the model's reply begins with "yes"
            (case, surrounding quotes and trailing punctuation ignored);
            otherwise the string "Unclaimed". "Unclaimed" is also returned
            when every attempt raised an exception.
        """
        for attempt in range(self.max_retries):
            try:
                # Pass the bare string: wrapping it in {"text": ...} here would
                # make the prompt render the whole dict instead of the name.
                result = self.chain.invoke(text)
            except Exception as e:
                print(f"Error classifying name (attempt {attempt + 1}/{self.max_retries}): {e}")
                if attempt >= self.max_retries - 1:
                    # No attempt left: sleeping again would only delay the fallback.
                    break
                if "rate_limit_exceeded" in str(e):
                    wait_time = self.base_wait_time * (2 ** attempt)
                    print(f"Rate limit exceeded. Waiting for {wait_time} seconds before retrying.")
                    time.sleep(wait_time)
                else:
                    time.sleep(5)
            else:
                # Compare only the first word so replies such as "Yes." or
                # "'Yes'" still count as affirmative.
                words = result.strip().lower().split()
                first_word = words[0].strip("'\".,!:;") if words else ""
                if first_word == "yes":
                    return text  # Return the original name if it's classified as a person's name
                return "Unclaimed"

        # Reached when all attempts failed (or max_retries is 0); always give
        # callers a string instead of an implicit None.
        print(f"All retry attempts failed. Returning 'Unclaimed' for '{text}'.")
        return "Unclaimed"
def process_name_list(name_list, api_key, model_name="mixtral-8x7b-32768"):
    """Classify each entry of ``name_list`` and collect the outcomes in a DataFrame.

    Args:
        name_list: Sequence of strings to classify.
        api_key: Groq API key forwarded to ``NameClassifier``.
        model_name: Groq model name forwarded to ``NameClassifier``.

    Returns:
        A ``pandas.DataFrame`` with one row per input and the columns
        "Input" and "Classification".

    Side effects:
        Prints progress for every entry and, after every 100th entry, writes
        all rows gathered so far to
        ``intermediate_name_classification_results_<count>.csv``.
    """
    classifier = NameClassifier(api_key, model_name)
    rows = []

    for count, candidate in enumerate(name_list, start=1):
        print(f"Processing name {count}/{len(name_list)}: {candidate}")
        verdict = classifier.classify_name(candidate)
        rows.append({"Input": candidate, "Classification": verdict})

        if count % 100 != 0:
            continue

        # Checkpoint: dump everything collected so far.
        checkpoint_path = f'intermediate_name_classification_results_{count}.csv'
        pd.DataFrame(rows).to_csv(checkpoint_path, index=False)
        print(f"Saved intermediate results at iteration {count}")

    return pd.DataFrame(rows)
# Example usage: classify a small sample and store the outcome as CSV.
if __name__ == "__main__":
    # Placeholder only -- replace with your actual API key.
    api_key = "your_api_key_here"

    # Mix of person names and non-person entities to exercise both outcomes.
    name_list = [
        "John Doe",
        "Apple Inc.",
        "Jane Smith",
        "New York City",
        "William Shakespeare",
        "Artificial Intelligence",
        "Emily Johnson",
        "Python Programming",
        "Michael Jackson",
        "United Nations",
    ]

    result_df = process_name_list(name_list=name_list, api_key=api_key)
    print(result_df)

    output_path = 'name_classification_results.csv'
    result_df.to_csv(output_path, index=False)
    print(f"Results saved to '{output_path}'")
import pandas as pd
import time
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
class NameClassifier:
    """Decide whether a piece of text is a person's name by asking a Groq chat model.

    The model is prompted to answer 'Yes' or 'Unclaimed'. Failed calls are
    retried up to ``max_retries`` times; rate-limit errors back off
    exponentially starting from ``base_wait_time`` seconds.
    """

    def __init__(self, api_key, model_name="mixtral-8x7b-32768", max_retries=5, base_wait_time=60):
        """Build the prompt -> model -> string-parser chain.

        Args:
            api_key: Groq API key handed to ``ChatGroq``.
            model_name: Name of the Groq model to query.
            max_retries: Maximum number of attempts per classified text.
            base_wait_time: Seconds to wait after the first rate-limit error;
                doubled on every further attempt.
        """
        self.llm = ChatGroq(model_name=model_name, api_key=api_key)
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant specialized in identifying person names. Your task is to determine if the given text is a person's name or not."),
            ("human", "Is the following text a person's name? Respond with 'Yes' if it's a person's name, or 'Unclaimed' if it's not.\n\nText: {text}\n\nClassification:")
        ])
        # The leading mapping places the chain's raw input under the "text"
        # key, so the chain must be invoked with the bare string.
        self.chain = (
            {"text": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        self.max_retries = max_retries
        self.base_wait_time = base_wait_time

    def classify_name(self, text):
        """Return ``text`` if the model says it is a person's name, else "Unclaimed".

        Args:
            text: The candidate string to classify.

        Returns:
            The original ``text`` when the model's reply begins with "yes"
            (case, surrounding quotes and trailing punctuation ignored);
            otherwise the string "Unclaimed". "Unclaimed" is also returned
            when every attempt raised an exception.
        """
        for attempt in range(self.max_retries):
            try:
                # Pass the bare string: wrapping it in {"text": ...} here would
                # make the prompt render the whole dict instead of the name.
                result = self.chain.invoke(text)
            except Exception as e:
                print(f"Error classifying name (attempt {attempt + 1}/{self.max_retries}): {e}")
                if attempt >= self.max_retries - 1:
                    # No attempt left: sleeping again would only delay the fallback.
                    break
                if "rate_limit_exceeded" in str(e):
                    wait_time = self.base_wait_time * (2 ** attempt)
                    print(f"Rate limit exceeded. Waiting for {wait_time} seconds before retrying.")
                    time.sleep(wait_time)
                else:
                    time.sleep(5)
            else:
                # Compare only the first word so replies such as "Yes." or
                # "'Yes'" still count as affirmative.
                words = result.strip().lower().split()
                first_word = words[0].strip("'\".,!:;") if words else ""
                if first_word == "yes":
                    return text  # Return the original name if it's classified as a person's name
                return "Unclaimed"

        # Reached when all attempts failed (or max_retries is 0); always give
        # callers a string instead of an implicit None.
        print(f"All retry attempts failed. Returning 'Unclaimed' for '{text}'.")
        return "Unclaimed"
def normalize_spaces(text):
    """Collapse each run of whitespace in ``text`` to one space and trim both ends."""
    # str.split() with no argument splits on any whitespace run and drops
    # empty leading/trailing pieces.
    tokens = text.split()
    return " ".join(tokens)
def process_name_list(name_list, api_key, model_name="mixtral-8x7b-32768"):
    """Classify each entry of ``name_list`` and collect the outcomes in a DataFrame.

    Args:
        name_list: Sequence of strings to classify.
        api_key: Groq API key forwarded to ``NameClassifier``.
        model_name: Groq model name forwarded to ``NameClassifier``.

    Returns:
        A ``pandas.DataFrame`` with one row per input and the columns
        "Input", "Classification" and "Normalized Name".

    Side effects:
        Prints progress for every entry and, after every 100th entry, writes
        all rows gathered so far to
        ``intermediate_name_classification_results_<count>.csv``.
    """
    classifier = NameClassifier(api_key, model_name)
    rows = []

    for count, candidate in enumerate(name_list, start=1):
        print(f"Processing name {count}/{len(name_list)}: {candidate}")

        # The classifier sees the raw input; the whitespace-normalized form
        # is recorded alongside it.
        verdict = classifier.classify_name(candidate)
        tidy_name = normalize_spaces(candidate)
        row = {
            "Input": candidate,
            "Classification": verdict,
            "Normalized Name": tidy_name,
        }
        rows.append(row)

        if count % 100 != 0:
            continue

        # Checkpoint: dump everything collected so far.
        checkpoint_path = f'intermediate_name_classification_results_{count}.csv'
        pd.DataFrame(rows).to_csv(checkpoint_path, index=False)
        print(f"Saved intermediate results at iteration {count}")

    return pd.DataFrame(rows)
# Example usage: classify a small sample list and write the result to CSV.
if __name__ == "__main__":
    api_key = "your_api_key_here"  # Replace with your actual API key

    # Mix of person names and non-person entities to exercise both outcomes.
    name_list = [
        "John Doe",
        "Apple Inc.",
        "Jane Smith",
        "New York City",
        "William Shakespeare",
        "Artificial Intelligence",
        "Emily Johnson",
        "Python Programming",
        "Michael Jackson",
        "United Nations"
    ]

    result_df = process_name_list(name_list, api_key)
    print(result_df)

    # Persist the final table next to the script's working directory.
    result_df.to_csv('name_classification_results.csv', index=False)
    print("Results saved to 'name_classification_results.csv'")
# Leave a Comment