Untitled

 avatar
unknown
plain_text
20 days ago
4.2 kB
2
Indexable
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout

# Seed the NumPy and TensorFlow RNGs so repeated runs are reproducible (optional).
np.random.seed(42)
tf.random.set_seed(42)

# --------------------------------------------------------------------------
# 1. Load data from the CSV file
# --------------------------------------------------------------------------
df = pd.read_csv("test.csv")  # Pass `encoding=...` here if the file needs it.
print("Pierwsze wiersze pliku:")
print(df.head())

# Drop the leftover index column that appears when a DataFrame was saved
# together with its index; ignored if the column is not present.
df = df.drop(columns=["Unnamed: 0"], errors='ignore')

required_columns = ["text", "label"]
for col in required_columns:
    if col not in df.columns:
        raise ValueError(f"Brak wymaganej kolumny: {col} w pliku CSV.")

# Rows with a missing text or label cannot be used for training. Without this,
# astype(str) below would turn a missing text into the literal word "nan".
df = df.dropna(subset=required_columns)
if df.empty:
    raise ValueError("Brak kompletnych wierszy (text, label) w pliku CSV.")

X_raw = df["text"].astype(str).values

# Map the raw labels (strings, ints with gaps, ...) onto contiguous ids
# 0..num_classes-1, which is what sparse_categorical_crossentropy and the
# softmax output layer require. class_labels[k] is the original label of id k;
# labels that are already 0..K-1 map onto themselves.
class_labels, y_raw = np.unique(df["label"].values, return_inverse=True)

num_classes = len(class_labels)
print("Liczba unikalnych klas:", num_classes)

# --------------------------------------------------------------------------
# 2. Train/test split
# --------------------------------------------------------------------------
# Integer class ids, as expected by sparse_categorical_crossentropy.
y = np.array(y_raw, dtype=np.int32)

# Split the raw texts *before* tokenizing so that the vocabulary below is
# built from training data only. random_state keeps the split reproducible.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X_raw, y, test_size=0.2, random_state=42
)

print("Rozmiar zbioru treningowego:", len(X_train_raw))
print("Rozmiar zbioru testowego:", len(X_test_raw))

# --------------------------------------------------------------------------
# 3. Tokenization and vectorization
# --------------------------------------------------------------------------
VOCAB_SIZE = 10000
MAX_LEN = 50

# Fit on the training texts only. Words that occur only in the test set then
# map to <OOV>, like any unseen word at inference time. They no longer get an
# embedding row that no training sample ever updates, and they cannot shift
# the top-VOCAB_SIZE frequency cutoff.
tokenizer = Tokenizer(num_words=VOCAB_SIZE, oov_token="<OOV>")
tokenizer.fit_on_texts(X_train_raw)

# Convert texts to index sequences and pad/truncate them at the end to exactly
# MAX_LEN tokens; the padding value is 0.
X_train = pad_sequences(
    tokenizer.texts_to_sequences(X_train_raw),
    maxlen=MAX_LEN, padding='post', truncating='post',
)
X_test = pad_sequences(
    tokenizer.texts_to_sequences(X_test_raw),
    maxlen=MAX_LEN, padding='post', truncating='post',
)

# --------------------------------------------------------------------------
# 4. Model definition (multi-class, single-label)
# --------------------------------------------------------------------------
model = Sequential([
    Input(shape=(MAX_LEN,)),
    # Index 0 is the padding value written by pad_sequences; the Tokenizer
    # never assigns it to a word. mask_zero=True makes the LSTM skip those
    # timesteps. With padding='post' the final hidden state would otherwise
    # be computed after running over a tail of padding vectors.
    Embedding(input_dim=VOCAB_SIZE, output_dim=64, mask_zero=True),
    LSTM(128),
    Dropout(0.3),
    # One probability per class; pairs with sparse_categorical_crossentropy,
    # which takes integer class ids 0..num_classes-1 as targets.
    Dense(num_classes, activation='softmax')
])

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

model.summary()

# --------------------------------------------------------------------------
# 5. Training
# --------------------------------------------------------------------------
EPOCHS = 5
BATCH_SIZE = 32

# The held-out split is passed as validation data purely for monitoring;
# no callback acts on it, so it does not influence the trained weights.
fit_options = {
    "validation_data": (X_test, y_test),
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
}
history = model.fit(X_train, y_train, **fit_options)

# --------------------------------------------------------------------------
# 6. Results and evaluation
# --------------------------------------------------------------------------
# Per-class probabilities for every test sample; the prediction is the argmax.
predictions = model.predict(X_test)
predicted_classes = np.argmax(predictions, axis=1)

# Metrics
accuracy = accuracy_score(y_test, predicted_classes)
# Pass the full label set explicitly. Otherwise classification_report infers it
# from y_test and the predictions, and raises ValueError when some class is
# absent from both, because the inferred count no longer matches target_names.
# zero_division=0 reports 0.0 for never-predicted classes without a warning.
class_ids = np.arange(num_classes)
report = classification_report(
    y_test,
    predicted_classes,
    labels=class_ids,
    target_names=[f"Class {i}" for i in class_ids],
    zero_division=0,
)

print("\nDokładność na zbiorze testowym:", accuracy)
print("\nRaport klasyfikacji:")
print(report)

# Show a few individual predictions.
SAMPLES_TO_PREDICT = 5
# Never index past the end of a small test set.
for i in range(min(SAMPLES_TO_PREDICT, len(X_test))):
    # Index 0 (padding) has no word, so it is dropped from the decoded text.
    decoded_text = tokenizer.sequences_to_texts([X_test[i]])[0]
    true_class = y_test[i]
    pred_class = predicted_classes[i]
    probability = predictions[i][pred_class]
    print(f"\n===== Przykład #{i + 1} =====")
    print(f"Tekst: {decoded_text}")
    print(f"Prawdziwa etykieta: {true_class}")
    print(f"Przewidywana etykieta: {pred_class} (prawdopodobieństwo: {probability:.4f})")
    print("=========================")
Leave a Comment