// Paste-site metadata (not part of the program): "Untitled", posted by
// mail@pastecode.io (unknown), plain_text, a month ago, 3.4 kB, Indexable, Never.
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

namespace Completed
{
    /// <summary>
    /// ML-Agents <see cref="Agent"/> attached to the same GameObject as a
    /// <see cref="Player"/>. It reports the transform's position and rotation as
    /// observations, forwards two discrete action branches to
    /// <c>Player.Movement_V1</c>, and exposes one public method per game event
    /// that adds a fixed reward.
    /// </summary>
    public class PlayerAgent : Agent
    {
        /// <summary>Selectable penalty mode. Not read anywhere inside this class.</summary>
        public enum PenaltyStyle
        {
            Reward,
            Termination
        }

        // Sibling component that receives the decoded actions; resolved in Initialize().
        private Player player;

        // Assigned in the inspector; its `phase` value is recorded in OnEpisodeEnd().
        [SerializeField] private GameManager gameManager;

        // The three public fields below are not read inside this class.
        // NOTE(review): presumably consumed by other scripts or the inspector - confirm before removing.
        public bool banana = false;
        public PenaltyStyle penaltyStyle;
        public float dmg;

        // Cached in Initialize() but not read anywhere in this class.
        private Rigidbody2D rb2d;
        private SpriteRenderer sr;

        /// <summary>
        /// One-time setup: enables running in the background and caches sibling
        /// components. Done in the ML-Agents <c>Initialize</c> hook rather than
        /// Unity's <c>Start</c> so the references are resolved as part of agent
        /// initialization, before any action can be delivered.
        /// </summary>
        public override void Initialize()
        {
            Application.runInBackground = true;

            player = GetComponent<Player>();
            rb2d = GetComponent<Rigidbody2D>();
            sr = GetComponent<SpriteRenderer>();

            // Unity's overloaded == is used on purpose: it also treats destroyed objects as null.
            if (player == null)
            {
                Debug.LogError("Player component not found on the GameObject.");
            }

            if (gameManager == null)
            {
                Debug.LogError("GameManager not assigned in the inspector.");
            }

            Debug.Log("Player agent has been initialized.");
        }

        /// <summary>
        /// Records the game manager's current phase under the "Phase Reached"
        /// statistic. Does nothing when no game manager is assigned (that
        /// misconfiguration is already reported by <see cref="Initialize"/>).
        /// </summary>
        public void OnEpisodeEnd()
        {
            if (gameManager == null)
            {
                return;
            }

            Academy.Instance.StatsRecorder.Add("Phase Reached", gameManager.phase);
        }

        /// <summary>
        /// Adds the transform's x position, y position and rotation to the
        /// vector observation.
        /// </summary>
        /// <param name="sensor">Sensor that receives the observations.</param>
        public override void CollectObservations(VectorSensor sensor)
        {
            sensor.AddObservation(transform.position.x);
            sensor.AddObservation(transform.position.y);
            // The rotation is passed as a Quaternion, so it contributes several floats.
            // NOTE(review): the configured vector observation size must match - confirm.
            sensor.AddObservation(transform.rotation);
        }

        /// <summary>
        /// Reads discrete branch 0 as the rotate command and branch 1 as the
        /// strafe command, logs both, and forwards them to the player.
        /// </summary>
        /// <param name="actionBuffers">Actions chosen for this step.</param>
        public override void OnActionReceived(ActionBuffers actionBuffers)
        {
            int rotate = actionBuffers.DiscreteActions[0];
            int strafe = actionBuffers.DiscreteActions[1];

            // NOTE(review): this logs on every received action; consider removing for long training runs.
            Debug.Log("Action chosen = " + rotate + ", " + strafe);

            // A missing Player was already reported in Initialize(); skip instead of
            // throwing a NullReferenceException on every step.
            if (player == null)
            {
                return;
            }

            player.Movement_V1(rotate, strafe);
        }

        /// <summary>
        /// Manual-control hook. Intentionally writes nothing, so
        /// <paramref name="actionsOut"/> is left exactly as it was passed in.
        /// </summary>
        /// <param name="actionsOut">Action buffers a heuristic would fill.</param>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
        }

        #region handle
        /// <summary>Adds a reward of 0.2.</summary>
        public void FoundPowerup() => AddReward(0.2f);

        /// <summary>Adds a reward of -0.2.</summary>
        public void LostShield() => AddReward(-0.2f);

        /// <summary>Adds a reward of 0.2.</summary>
        public void DestroyBoss() => AddReward(0.2f);

        /// <summary>Adds a reward of 0.4.</summary>
        public void DestroyBossSpawn() => AddReward(0.4f);

        /// <summary>Adds a reward of -1.0. Does not end the episode itself.</summary>
        public void GameLose() => AddReward(-1.0f);

        /// <summary>Adds a reward of -0.2.</summary>
        public void HealthLoss() => AddReward(-0.2f);

        /// <summary>Adds a reward of 0.3.</summary>
        public void HealthGain() => AddReward(0.3f);

        /// <summary>Adds a reward of 0.4.</summary>
        public void DestroySpawner() => AddReward(0.4f);

        /// <summary>Adds a reward of 0.1.</summary>
        public void DestroyEnemy() => AddReward(0.1f);

        /// <summary>Adds a reward of 0.5.</summary>
        public void AdvancedHardPhase() => AddReward(0.5f);

        /// <summary>Adds a reward of 0.2.</summary>
        public void AdvancedPhase() => AddReward(0.2f);

        /// <summary>Adds a reward of 1.0. Does not end the episode itself.</summary>
        public void WonGame() => AddReward(1.0f);
        #endregion
    }
}
// (paste-site footer, not part of the program: "Leave a Comment")