Untitled

 avatar
unknown
plain_text
a month ago
17 kB
4
Indexable
from airflow import DAG
from airflow.utils.task_group import TaskGroup  # Add this import
from airflow.operators.python import PythonOperator
from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.models import Variable
from airflow.exceptions import AirflowTaskTimeout
from airflow.models.baseoperator import chain

from datetime import datetime, timedelta
import sys
import os
import logging
import timeout_decorator
import time
import json
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import random

sys.path.extend(
    [
        "/app/scraper/src/fetchers/reference_data/tradingview",
        "/app/scraper/src/database/reference_data",
    ]
)

from sector_and_industries import *

# from app.scraper.src.fetchers.reference_data.tradingview.sector_and_industries import *
from cassandra_client import CassandraClient  # pylint: disable=import-error

# Defaults handed to the DAG below via ``default_args``: up to three retries,
# five minutes apart, and a 60-minute execution timeout.
default_args = dict(
    owner="airflow",
    depends_on_past=False,
    start_date=datetime(2024, 2, 17),
    retries=3,
    retry_delay=timedelta(minutes=5),
    execution_timeout=timedelta(minutes=60),
)


# ==============================
# 🔹 Step 0: load the list of countries
# ==============================
# Absolute path of the JSON file that lists the countries to process.
# Relative alternative that was used previously:
#   scraper/src/database/reference_data/scraping_raw_json/tradingview/countries_with_flags.json
COUNTRIES_JSON_PATH = (
    "/app/scraper/src/database/reference_data"
    "/scraping_raw_json/tradingview/countries_with_flags.json"
)


def clean_country_name(country: str) -> str:
    """Return *country* normalised for use inside a task or group ID.

    The name is lower-cased, spaces and hyphens become underscores, and
    parentheses are dropped. Every other character is left as it is.
    """
    # One translation table instead of a chain of replace() calls.
    substitutions = str.maketrans({" ": "_", "-": "_", "(": None, ")": None})
    lowered = country.lower()
    return lowered.translate(substitutions)


def get_all_countries():
    """Read ``COUNTRIES_JSON_PATH`` and return the lower-cased country names.

    The JSON document is iterated as a sequence of entries, each holding a
    ``"countries"`` list whose items carry a ``"country"`` name. The names are
    collected in file order, printed, and returned as a list.
    """
    with open(COUNTRIES_JSON_PATH, "r", encoding="utf-8") as handle:
        regions = json.load(handle)

    names = []
    for region in regions:
        for entry in region["countries"]:
            names.append(entry["country"].lower())

    print(names)
    return names


# Country names resolved once, when this module is imported; the DAG
# definition further down loops over this list to create per-country tasks.
all_countries = get_all_countries()

# ==============================
# 🔹 Step 1: get sectors and industries
# ==============================
# Dataset kinds processed for every country. The name is also used as the
# suffix of the Cassandra table the data goes to (``tradingview_<dataset>``).
datasets = ["sectors", "industries"]


def fetch_tradingview_sector_and_industry_by_country(country, dataset):
    """Scrape one TradingView dataset for a country and store it in Cassandra.

    Args:
        country: Country name. It is lower-cased and its spaces are replaced
            with hyphens before being handed to the scraper.
        dataset: Either ``"sectors"`` or ``"industries"``; also selects the
            target table ``tradingview_<dataset>``.

    Returns:
        A human-readable success message.

    Raises:
        ValueError: If ``dataset`` is not a supported name, or if the scraper
            returns ``None`` or an empty DataFrame.
    """
    # Validate up front: an unknown dataset used to skip both branches below
    # and fail later with an unrelated UnboundLocalError on ``df``.
    if dataset not in ("sectors", "industries"):
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    cassandra_client = CassandraClient()
    cassandra_client.connect()
    try:
        cassandra_client.create_keyspace()

        country_source = country.lower().replace(" ", "-")
        if dataset == "sectors":
            df = get_tradingview_sectors_by_country(country=country_source)
        else:
            df = get_tradingview_industries_by_country(country=country_source)

        # Check before displaying so view_data() never receives None.
        if df is None or df.empty:
            raise ValueError("❌ ERROR: df is None or empty! Check the data source.")
        view_data(df)

        table_name = f"tradingview_{dataset}"
        cassandra_client.insert_df_to_cassandra(df, table_name)
    finally:
        # Release the connection on every path, including the errors above.
        cassandra_client.close()

    return f"Fetch successfully {dataset} of {country}!!!!"


# ==============================
# 🔹 Step 2: get components
# ==============================
# @task(execution_timeout=timedelta(minutes=30), do_xcom_push=False)
def fetch_tradingview_sector_and_industry_components_by_country(country, dataset):
    """Fetch and store the components of every sector/industry of a country.

    Reads the rows of ``tradingview_<dataset>`` for ``country`` from Cassandra,
    scrapes the component list behind each row's ``component_url`` using five
    worker threads, and inserts the combined result into
    ``tradingview_icb_components``.

    Args:
        country: Country value to filter the source table on.
        dataset: Dataset name used to build the source table name. It is
            interpolated into the query text, so it must come from trusted
            code (table names cannot be bound as query parameters).

    Returns:
        A status message: either a success message or a note that no source
        rows exist for the country.
    """

    def fetch_one(link):
        """Scrape one component page and tag its rows with the source link."""
        if not link:
            return pd.DataFrame()
        print(f"Fetching components from {link}")
        # Random 1-10 s pause before each request (presumably to throttle
        # the scraper -- confirm).
        time.sleep(random.randint(1, 10))
        return get_tradingview_sectors_industries_components(link).assign(
            component_url=link
        )

    cassandra_client = CassandraClient()
    cassandra_client.connect()
    try:
        cassandra_client.create_keyspace()

        table_name = f"tradingview_{dataset}"
        print(table_name)
        query = f"SELECT * FROM {table_name} WHERE country = %s ALLOW FILTERING "
        source_df = cassandra_client.query_to_dataframe(query, (str(country),))

        if source_df.empty:
            print(f"⚠️ No data found for country: {country} in {dataset}. Skipping...")
            return f"No data for {country}, skipping..."

        view_data(source_df)
        component_links = source_df["component_url"].tolist()

        with ThreadPoolExecutor(max_workers=5) as executor:
            frames = list(executor.map(fetch_one, component_links))

        components_df = pd.concat(frames, ignore_index=True)
        view_data(components_df)

        cassandra_client.insert_df_to_cassandra(
            components_df, "tradingview_icb_components"
        )
    finally:
        # Close the connection on every path, including scraper failures.
        cassandra_client.close()

    return f"fetch successfully {dataset} components of {country}!!!!"


def fetch_components_by_list(country, dataset, component_urls):
    """Fetch the components behind a subset of component URLs and store them.

    Each URL is scraped using five worker threads; the combined rows are
    inserted into ``tradingview_icb_components`` when there are any.

    Args:
        country: Country the links belong to (used in the returned message).
        dataset: Dataset the links belong to (used in the returned message).
        component_urls: Collection of component page URLs. Falsy entries
            contribute an empty frame instead of being scraped.

    Returns:
        A status message stating how many links were handled.
    """

    def fetch_one(link):
        """Scrape one component page and tag its rows with the source link."""
        if not link:
            return pd.DataFrame()
        print(f"Fetching components from {link}")
        time.sleep(random.randint(1, 10))
        return get_tradingview_sectors_industries_components(link).assign(
            component_url=link
        )

    links = list(component_urls)
    summary = f"Fetched components for {len(links)} links in {dataset} of {country}"

    # Nothing to do: avoid opening a connection, and avoid pd.concat([]),
    # which raises ValueError on an empty list.
    if not links:
        return summary

    cassandra_client = CassandraClient()
    cassandra_client.connect()
    try:
        cassandra_client.create_keyspace()

        with ThreadPoolExecutor(max_workers=5) as executor:
            frames = list(executor.map(fetch_one, links))

        components_df = pd.concat(frames, ignore_index=True)
        if not components_df.empty:
            cassandra_client.insert_df_to_cassandra(
                components_df, "tradingview_icb_components"
            )
    finally:
        # Close the connection even when a scrape or the insert fails.
        cassandra_client.close()

    return summary


# ==============================
# 🔹 Step 3 & 4: fetch component batches, then concat and save
# ==============================
def fetch_components_batch(links, country, dataset, batch_index):
    """Fetch the components for one batch of links and store them.

    Args:
        links: Component page URLs of this batch. Falsy entries contribute an
            empty frame instead of being scraped.
        country: Country the batch belongs to (used in log messages).
        dataset: Dataset the batch belongs to (used in log messages).
        batch_index: Position of the batch (used in log messages).

    Returns:
        The combined components DataFrame after it has been inserted into
        ``tradingview_icb_components``, or a status string when the batch has
        no links or yields no rows.

    Raises:
        Exception: Any error raised while connecting, scraping or inserting is
            logged and re-raised unchanged.
    """
    # NOTE(review): when this is used as a PythonOperator callable the returned
    # DataFrame is pushed to XCom -- confirm the XCom backend can serialize it.

    def fetch_one(link):
        """Scrape one component page and tag its rows with the source link."""
        logging.info(f"Fetching components from {link}")
        time.sleep(random.randint(1, 3))  # Reduced sleep time
        if not link:
            return pd.DataFrame()
        return get_tradingview_sectors_industries_components(link).assign(
            component_url=link
        )

    batch_links = [] if links is None else list(links)
    if not batch_links:
        # An empty batch needs no Cassandra connection at all.
        logging.warning(
            f"No components fetched for batch {batch_index} of {dataset} in {country}"
        )
        return f"No components for batch {batch_index} of {dataset} in {country}"

    try:
        cassandra_client = CassandraClient()
        cassandra_client.connect()
        try:
            cassandra_client.create_keyspace()

            with ThreadPoolExecutor(max_workers=5) as executor:
                frames = list(executor.map(fetch_one, batch_links))

            components_df = pd.concat(frames, ignore_index=True)
            if components_df.empty:
                logging.warning(
                    f"Empty components DataFrame for batch {batch_index} of {dataset} in {country}"
                )
                return f"Empty components for batch {batch_index} of {dataset} in {country}"

            cassandra_client.insert_df_to_cassandra(
                components_df, "tradingview_icb_components"
            )
            return components_df
        finally:
            # Previously the connection stayed open whenever scraping or the
            # insert raised; close it on every path.
            cassandra_client.close()
    except Exception as e:
        logging.error(
            f"Error fetching components for batch {batch_index} of {dataset} in {country}: {str(e)}"
        )
        raise


def concat_and_save(**kwargs):
    """Pull every country's sector/industry results from XCom and print them.

    For each country in ``all_countries`` and each of the two datasets, the
    dataset value and its components value are pulled from XCom, combined per
    category with ``pd.concat`` and printed. Despite the name, nothing is
    written anywhere by this function.

    Args:
        **kwargs: Airflow task context; only ``kwargs["ti"]`` is used.
    """
    # NOTE(review): no task in the visible DAG definition pushes these XCom
    # keys, and no task group named "fetch_components_tasks" is created there.
    # Confirm the producers before wiring this function into the DAG.
    ti = kwargs["ti"]
    dataset_names = ("sectors", "industries")
    pulled = {name: [] for name in dataset_names}
    pulled_components = {name: [] for name in dataset_names}

    for country in all_countries:
        cleaned = clean_country_name(country)
        # The dataset name is spelled out per pull. The old code read the
        # module-level loop variable ``dataset`` left behind by the DAG
        # definition, so sector pulls were keyed with the wrong dataset.
        for name in dataset_names:
            data = ti.xcom_pull(
                task_ids=f"fetch_sectors_and_industries_tasks.fetch_{cleaned}_{name}",
                key=f"{cleaned}_{name}",
            )
            component_data = ti.xcom_pull(
                task_ids=f"fetch_components_tasks.fetch_{cleaned}_{name}_components",
                key=f"{cleaned}_{name}_components",
            )
            # xcom_pull returns None when nothing was pushed; skip those.
            if data is not None:
                pulled[name].append(data)
            if component_data is not None:
                pulled_components[name].append(component_data)

    def combine(frames):
        """Concatenate *frames*, or return an empty frame when there are none."""
        # pd.concat raises ValueError on an empty list.
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    final_sectors = combine(pulled["sectors"])
    final_sectors_components = combine(pulled_components["sectors"])
    final_industries = combine(pulled["industries"])
    final_industries_components = combine(pulled_components["industries"])

    print(final_sectors)
    print(final_sectors_components)
    print(final_industries)
    print(final_industries_components)


# Define DAG
#
# Layout: get_list_countries >> one sectors/industries fetch task per
# (country, dataset) >> one task group per country holding the two tasks that
# fetch and store that country's components.
#
# The previous version called ``.execute(context={})`` on the components
# operator while this file was being parsed. That ran the scraper and the
# Cassandra queries on every DAG parse. It also sliced the returned status
# string into 5-character "batches" that were then passed to
# fetch_components_batch as if they were URL lists. Nothing may be executed at
# parse time, so that fan-out is gone. The components task already scrapes and
# stores every component URL itself.
#
# To split the work into per-batch tasks, resolve the URLs in a task at run
# time and fan out with dynamic task mapping
# (``PythonOperator.partial(...).expand(op_args=...)``, Airflow >= 2.3).
with DAG(
    "trading_view_get_sectors_and_industries",
    default_args=default_args,
    # schedule_interval=timedelta(minutes=90),
    schedule_interval=None,
    catchup=False,
    # ``concurrency`` was dropped: it is the deprecated alias of
    # ``max_active_tasks`` and carried the same value.
    max_active_tasks=32,
) as dag:

    list_countries_task = PythonOperator(
        task_id="get_list_countries",
        python_callable=get_all_countries,
    )

    # One task per (country, dataset) that scrapes the dataset and stores it.
    with TaskGroup(
        "fetch_sectors_and_industries_tasks"
    ) as fetch_sectors_and_industries_group:
        for country in all_countries:
            for dataset in datasets:
                PythonOperator(
                    task_id=f"fetch_{clean_country_name(country)}_{dataset}",
                    python_callable=fetch_tradingview_sector_and_industry_by_country,
                    op_args=[country, dataset],
                )

    # One group per country with a components task for each dataset.
    country_groups = []

    for country in all_countries:
        cleaned_country = clean_country_name(country)
        with TaskGroup(f"fetch_{cleaned_country}_tasks") as country_group:
            for dataset in datasets:
                PythonOperator(
                    task_id=f"fetch_{cleaned_country}_{dataset}_components_urls",
                    python_callable=fetch_tradingview_sector_and_industry_components_by_country,
                    op_args=[country, dataset],
                )

        country_groups.append(country_group)

    list_countries_task >> fetch_sectors_and_industries_group >> country_groups


def cleanup(context=None):
    """DAG-level callback run after the DAG run succeeds or fails.

    Closes a module-level ``cassandra_client`` if one exists. The task
    callables in this module each open and close their own client, so normally
    there is nothing to close here.

    Args:
        context: Context passed by Airflow when it invokes the callback.
            Optional so the function can still be called without arguments.
    """
    # The old body referenced ``cassandra_client`` directly; no such global is
    # defined in this module, so every callback invocation raised NameError.
    shared_client = globals().get("cassandra_client")
    if shared_client is None:
        logging.info("No shared Cassandra connection to close.")
        return

    logging.info("Closing Cassandra connection.")
    shared_client.close()


# Register cleanup() to run when a DAG run finishes, on success and on failure.
# NOTE(review): Airflow invokes these callbacks with a context argument --
# make sure cleanup() accepts one.
dag.on_success_callback = cleanup
dag.on_failure_callback = cleanup
Editor is loading...
Leave a Comment